feat: Enhance OpenAI compatibility and introduce unified LLM service data models
- Refactored model fetching logic to support various authentication methods for OpenAI-compatible APIs. - Introduced new data models for LLM service requests and responses to standardize interactions across layers. - Added an adapter base class for unified request execution across different providers. - Implemented utility functions for building OpenAI-compatible client configurations and request overrides.
This commit is contained in:
259
src/llm_models/model_client/adapter_base.py
Normal file
259
src/llm_models/model_client/adapter_base.py
Normal file
@@ -0,0 +1,259 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Coroutine, Generic, Tuple, TypeVar, cast
|
||||
|
||||
import asyncio
|
||||
|
||||
from src.config.model_configs import ModelInfo
|
||||
|
||||
from .base_client import (
|
||||
APIResponse,
|
||||
AudioTranscriptionRequest,
|
||||
BaseClient,
|
||||
EmbeddingRequest,
|
||||
ResponseRequest,
|
||||
UsageRecord,
|
||||
UsageTuple,
|
||||
)
|
||||
|
||||
RawStreamT = TypeVar("RawStreamT")
|
||||
"""流式原始响应类型变量。"""
|
||||
|
||||
RawResponseT = TypeVar("RawResponseT")
|
||||
"""非流式原始响应类型变量。"""
|
||||
|
||||
TaskResultT = TypeVar("TaskResultT")
|
||||
"""异步任务返回值类型变量。"""
|
||||
|
||||
ProviderStreamResponseHandler = Callable[
|
||||
[RawStreamT, asyncio.Event | None],
|
||||
Coroutine[Any, Any, Tuple[APIResponse, UsageTuple | None]],
|
||||
]
|
||||
"""Provider 专用流式响应处理函数类型。"""
|
||||
|
||||
ProviderResponseParser = Callable[[RawResponseT], Tuple[APIResponse, UsageTuple | None]]
|
||||
"""Provider 专用非流式响应解析函数类型。"""
|
||||
|
||||
|
||||
async def await_task_with_interrupt(
|
||||
task: asyncio.Task[TaskResultT],
|
||||
interrupt_flag: asyncio.Event | None,
|
||||
*,
|
||||
interval_seconds: float = 0.1,
|
||||
) -> TaskResultT:
|
||||
"""在支持外部中断的前提下等待异步任务完成。
|
||||
|
||||
Args:
|
||||
task: 待等待的异步任务。
|
||||
interrupt_flag: 外部中断标记。
|
||||
interval_seconds: 轮询检查间隔,单位秒。
|
||||
|
||||
Returns:
|
||||
TaskResultT: 任务执行结果。
|
||||
|
||||
Raises:
|
||||
ReqAbortException: 等待期间收到外部中断信号时抛出。
|
||||
"""
|
||||
from src.llm_models.exceptions import ReqAbortException
|
||||
|
||||
while not task.done():
|
||||
if interrupt_flag and interrupt_flag.is_set():
|
||||
task.cancel()
|
||||
raise ReqAbortException("请求被外部信号中断")
|
||||
await asyncio.sleep(interval_seconds)
|
||||
return await task
|
||||
|
||||
|
||||
class AdapterClient(BaseClient, ABC, Generic[RawStreamT, RawResponseT]):
|
||||
"""提供统一请求执行骨架的 Provider 适配基类。"""
|
||||
|
||||
async def get_response(self, request: ResponseRequest) -> APIResponse:
|
||||
"""获取对话响应。
|
||||
|
||||
Args:
|
||||
request: 统一响应请求对象。
|
||||
|
||||
Returns:
|
||||
APIResponse: 解析完成的统一响应对象。
|
||||
"""
|
||||
stream_response_handler = self._resolve_stream_response_handler(request)
|
||||
response_parser = self._resolve_response_parser(request)
|
||||
response, usage_record = await self._execute_response_request(
|
||||
request,
|
||||
stream_response_handler,
|
||||
response_parser,
|
||||
)
|
||||
return self._attach_usage_record(response, request.model_info, usage_record)
|
||||
|
||||
async def get_embedding(self, request: EmbeddingRequest) -> APIResponse:
|
||||
"""获取文本嵌入。
|
||||
|
||||
Args:
|
||||
request: 统一嵌入请求对象。
|
||||
|
||||
Returns:
|
||||
APIResponse: 解析完成的统一嵌入响应。
|
||||
"""
|
||||
response, usage_record = await self._execute_embedding_request(request)
|
||||
return self._attach_usage_record(response, request.model_info, usage_record)
|
||||
|
||||
async def get_audio_transcriptions(self, request: AudioTranscriptionRequest) -> APIResponse:
|
||||
"""获取音频转录。
|
||||
|
||||
Args:
|
||||
request: 统一音频转录请求对象。
|
||||
|
||||
Returns:
|
||||
APIResponse: 解析完成的统一音频转录响应。
|
||||
"""
|
||||
response, usage_record = await self._execute_audio_transcription_request(request)
|
||||
return self._attach_usage_record(response, request.model_info, usage_record)
|
||||
|
||||
def _resolve_stream_response_handler(
|
||||
self,
|
||||
request: ResponseRequest,
|
||||
) -> ProviderStreamResponseHandler[RawStreamT]:
|
||||
"""解析实际使用的流式响应处理器。
|
||||
|
||||
Args:
|
||||
request: 统一响应请求对象。
|
||||
|
||||
Returns:
|
||||
ProviderStreamResponseHandler[RawStreamT]: 流式响应处理器。
|
||||
"""
|
||||
if request.stream_response_handler is not None:
|
||||
return cast(ProviderStreamResponseHandler[RawStreamT], request.stream_response_handler)
|
||||
return self._build_default_stream_response_handler(request)
|
||||
|
||||
def _resolve_response_parser(
|
||||
self,
|
||||
request: ResponseRequest,
|
||||
) -> ProviderResponseParser[RawResponseT]:
|
||||
"""解析实际使用的非流式响应解析器。
|
||||
|
||||
Args:
|
||||
request: 统一响应请求对象。
|
||||
|
||||
Returns:
|
||||
ProviderResponseParser[RawResponseT]: 非流式响应解析器。
|
||||
"""
|
||||
if request.async_response_parser is not None:
|
||||
return cast(ProviderResponseParser[RawResponseT], request.async_response_parser)
|
||||
return self._build_default_response_parser(request)
|
||||
|
||||
@staticmethod
|
||||
def _build_usage_record(model_info: ModelInfo, usage_record: UsageTuple) -> UsageRecord:
|
||||
"""根据统一使用量三元组构建 `UsageRecord`。
|
||||
|
||||
Args:
|
||||
model_info: 模型信息。
|
||||
usage_record: 使用量三元组。
|
||||
|
||||
Returns:
|
||||
UsageRecord: 可直接挂载到 `APIResponse` 的使用记录对象。
|
||||
"""
|
||||
return UsageRecord(
|
||||
model_name=model_info.name,
|
||||
provider_name=model_info.api_provider,
|
||||
prompt_tokens=usage_record[0],
|
||||
completion_tokens=usage_record[1],
|
||||
total_tokens=usage_record[2],
|
||||
)
|
||||
|
||||
def _attach_usage_record(
|
||||
self,
|
||||
response: APIResponse,
|
||||
model_info: ModelInfo,
|
||||
usage_record: UsageTuple | None,
|
||||
) -> APIResponse:
|
||||
"""在响应对象上附加统一使用量信息。
|
||||
|
||||
Args:
|
||||
response: 已解析的统一响应对象。
|
||||
model_info: 模型信息。
|
||||
usage_record: 可选的使用量三元组。
|
||||
|
||||
Returns:
|
||||
APIResponse: 附加使用量后的响应对象。
|
||||
"""
|
||||
if usage_record is not None:
|
||||
response.usage = self._build_usage_record(model_info, usage_record)
|
||||
return response
|
||||
|
||||
@abstractmethod
|
||||
def _build_default_stream_response_handler(
|
||||
self,
|
||||
request: ResponseRequest,
|
||||
) -> ProviderStreamResponseHandler[RawStreamT]:
|
||||
"""构建默认流式响应处理器。
|
||||
|
||||
Args:
|
||||
request: 统一响应请求对象。
|
||||
|
||||
Returns:
|
||||
ProviderStreamResponseHandler[RawStreamT]: 默认流式处理器。
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _build_default_response_parser(
|
||||
self,
|
||||
request: ResponseRequest,
|
||||
) -> ProviderResponseParser[RawResponseT]:
|
||||
"""构建默认非流式响应解析器。
|
||||
|
||||
Args:
|
||||
request: 统一响应请求对象。
|
||||
|
||||
Returns:
|
||||
ProviderResponseParser[RawResponseT]: 默认非流式解析器。
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def _execute_response_request(
|
||||
self,
|
||||
request: ResponseRequest,
|
||||
stream_response_handler: ProviderStreamResponseHandler[RawStreamT],
|
||||
response_parser: ProviderResponseParser[RawResponseT],
|
||||
) -> Tuple[APIResponse, UsageTuple | None]:
|
||||
"""执行 Provider 的文本/多模态响应请求。
|
||||
|
||||
Args:
|
||||
request: 统一响应请求对象。
|
||||
stream_response_handler: 流式响应处理器。
|
||||
response_parser: 非流式响应解析器。
|
||||
|
||||
Returns:
|
||||
Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def _execute_embedding_request(
|
||||
self,
|
||||
request: EmbeddingRequest,
|
||||
) -> Tuple[APIResponse, UsageTuple | None]:
|
||||
"""执行 Provider 的嵌入请求。
|
||||
|
||||
Args:
|
||||
request: 统一嵌入请求对象。
|
||||
|
||||
Returns:
|
||||
Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def _execute_audio_transcription_request(
|
||||
self,
|
||||
request: AudioTranscriptionRequest,
|
||||
) -> Tuple[APIResponse, UsageTuple | None]:
|
||||
"""执行 Provider 的音频转录请求。
|
||||
|
||||
Args:
|
||||
request: 统一音频转录请求对象。
|
||||
|
||||
Returns:
|
||||
Tuple[APIResponse, UsageTuple | None]: 统一响应对象与可选使用量信息。
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@@ -1,14 +1,15 @@
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Any, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Coroutine, Dict, List, Tuple, Type
|
||||
|
||||
import asyncio
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import config_manager
|
||||
from src.config.model_configs import ModelInfo, APIProvider
|
||||
from ..payload_content.message import Message
|
||||
from ..payload_content.resp_format import RespFormat
|
||||
from ..payload_content.tool_option import ToolOption, ToolCall
|
||||
from src.config.model_configs import APIProvider, ModelInfo
|
||||
from src.llm_models.payload_content.message import Message
|
||||
from src.llm_models.payload_content.resp_format import RespFormat
|
||||
from src.llm_models.payload_content.tool_option import ToolCall, ToolOption
|
||||
|
||||
logger = get_logger("model_client_registry")
|
||||
|
||||
@@ -47,10 +48,10 @@ class APIResponse:
|
||||
reasoning_content: str | None = None
|
||||
"""推理内容"""
|
||||
|
||||
tool_calls: list[ToolCall] | None = None
|
||||
tool_calls: List[ToolCall] | None = None
|
||||
"""工具调用 [(工具名称, 工具参数), ...]"""
|
||||
|
||||
embedding: list[float] | None = None
|
||||
embedding: List[float] | None = None
|
||||
"""嵌入向量"""
|
||||
|
||||
usage: UsageRecord | None = None
|
||||
@@ -60,6 +61,82 @@ class APIResponse:
|
||||
"""响应原始数据"""
|
||||
|
||||
|
||||
UsageTuple = Tuple[int, int, int]
|
||||
"""统一的使用量三元组类型,顺序为 `(prompt_tokens, completion_tokens, total_tokens)`。"""
|
||||
|
||||
StreamResponseHandler = Callable[
|
||||
[Any, asyncio.Event | None],
|
||||
Coroutine[Any, Any, Tuple["APIResponse", UsageTuple | None]],
|
||||
]
|
||||
"""统一的流式响应处理函数类型。"""
|
||||
|
||||
ResponseParser = Callable[[Any], Tuple["APIResponse", UsageTuple | None]]
|
||||
"""统一的非流式响应解析函数类型。"""
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ResponseRequest:
|
||||
"""统一的文本/多模态响应请求。"""
|
||||
|
||||
model_info: ModelInfo
|
||||
message_list: List[Message]
|
||||
tool_options: List[ToolOption] | None = None
|
||||
max_tokens: int | None = None
|
||||
temperature: float | None = None
|
||||
response_format: RespFormat | None = None
|
||||
stream_response_handler: StreamResponseHandler | None = None
|
||||
async_response_parser: ResponseParser | None = None
|
||||
interrupt_flag: asyncio.Event | None = None
|
||||
extra_params: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def copy_with(self, **changes: Any) -> "ResponseRequest":
|
||||
"""基于当前请求创建一个带局部变更的新请求。
|
||||
|
||||
Args:
|
||||
**changes: 需要覆盖的字段值。
|
||||
|
||||
Returns:
|
||||
ResponseRequest: 复制后的请求对象。
|
||||
"""
|
||||
payload = {
|
||||
"model_info": self.model_info,
|
||||
"message_list": list(self.message_list),
|
||||
"tool_options": None if self.tool_options is None else list(self.tool_options),
|
||||
"max_tokens": self.max_tokens,
|
||||
"temperature": self.temperature,
|
||||
"response_format": self.response_format,
|
||||
"stream_response_handler": self.stream_response_handler,
|
||||
"async_response_parser": self.async_response_parser,
|
||||
"interrupt_flag": self.interrupt_flag,
|
||||
"extra_params": dict(self.extra_params),
|
||||
}
|
||||
payload.update(changes)
|
||||
return ResponseRequest(**payload)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class EmbeddingRequest:
|
||||
"""统一的嵌入请求。"""
|
||||
|
||||
model_info: ModelInfo
|
||||
embedding_input: str
|
||||
extra_params: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class AudioTranscriptionRequest:
|
||||
"""统一的音频转录请求。"""
|
||||
|
||||
model_info: ModelInfo
|
||||
audio_base64: str
|
||||
max_tokens: int | None = None
|
||||
extra_params: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
ClientRequest = ResponseRequest | EmbeddingRequest | AudioTranscriptionRequest
|
||||
"""统一客户端请求类型。"""
|
||||
|
||||
|
||||
class BaseClient(ABC):
|
||||
"""
|
||||
基础客户端
|
||||
@@ -67,97 +144,82 @@ class BaseClient(ABC):
|
||||
|
||||
api_provider: APIProvider
|
||||
|
||||
def __init__(self, api_provider: APIProvider):
|
||||
def __init__(self, api_provider: APIProvider) -> None:
|
||||
"""初始化基础客户端。
|
||||
|
||||
Args:
|
||||
api_provider: API 提供商配置。
|
||||
"""
|
||||
self.api_provider = api_provider
|
||||
|
||||
@abstractmethod
|
||||
async def get_response(
|
||||
self,
|
||||
model_info: ModelInfo,
|
||||
message_list: list[Message],
|
||||
tool_options: list[ToolOption] | None = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
response_format: RespFormat | None = None,
|
||||
stream_response_handler: Optional[
|
||||
Callable[[Any, asyncio.Event | None], tuple[APIResponse, tuple[int, int, int]]]
|
||||
] = None,
|
||||
async_response_parser: Callable[[Any], tuple[APIResponse, tuple[int, int, int]]] | None = None,
|
||||
interrupt_flag: asyncio.Event | None = None,
|
||||
extra_params: dict[str, Any] | None = None,
|
||||
) -> APIResponse:
|
||||
"""
|
||||
获取对话响应
|
||||
:param model_info: 模型信息
|
||||
:param message_list: 对话体
|
||||
:param tool_options: 工具选项(可选,默认为None)
|
||||
:param max_tokens: 最大token数(可选,默认为1024)
|
||||
:param temperature: 温度(可选,默认为0.7)
|
||||
:param response_format: 响应格式(可选,默认为 NotGiven )
|
||||
:param stream_response_handler: 流式响应处理函数(可选)
|
||||
:param async_response_parser: 响应解析函数(可选)
|
||||
:param interrupt_flag: 中断信号量(可选,默认为None)
|
||||
:return: (响应文本, 推理文本, 工具调用, 其他数据)
|
||||
async def get_response(self, request: ResponseRequest) -> APIResponse:
|
||||
"""获取对话响应。
|
||||
|
||||
Args:
|
||||
request: 统一响应请求对象。
|
||||
|
||||
Returns:
|
||||
APIResponse: 统一响应对象。
|
||||
"""
|
||||
raise NotImplementedError("'get_response' method should be overridden in subclasses")
|
||||
|
||||
@abstractmethod
|
||||
async def get_embedding(
|
||||
self,
|
||||
model_info: ModelInfo,
|
||||
embedding_input: str,
|
||||
extra_params: dict[str, Any] | None = None,
|
||||
) -> APIResponse:
|
||||
"""
|
||||
获取文本嵌入
|
||||
:param model_info: 模型信息
|
||||
:param embedding_input: 嵌入输入文本
|
||||
:return: 嵌入响应
|
||||
async def get_embedding(self, request: EmbeddingRequest) -> APIResponse:
|
||||
"""获取文本嵌入。
|
||||
|
||||
Args:
|
||||
request: 统一嵌入请求对象。
|
||||
|
||||
Returns:
|
||||
APIResponse: 嵌入响应。
|
||||
"""
|
||||
raise NotImplementedError("'get_embedding' method should be overridden in subclasses")
|
||||
|
||||
@abstractmethod
|
||||
async def get_audio_transcriptions(
|
||||
self,
|
||||
model_info: ModelInfo,
|
||||
audio_base64: str,
|
||||
max_tokens: Optional[int] = None,
|
||||
extra_params: dict[str, Any] | None = None,
|
||||
) -> APIResponse:
|
||||
"""
|
||||
获取音频转录
|
||||
:param model_info: 模型信息
|
||||
:param audio_base64: base64编码的音频数据
|
||||
:extra_params: 附加的请求参数
|
||||
:return: 音频转录响应
|
||||
async def get_audio_transcriptions(self, request: AudioTranscriptionRequest) -> APIResponse:
|
||||
"""获取音频转录。
|
||||
|
||||
Args:
|
||||
request: 统一音频转录请求对象。
|
||||
|
||||
Returns:
|
||||
APIResponse: 音频转录响应。
|
||||
"""
|
||||
raise NotImplementedError("'get_audio_transcriptions' method should be overridden in subclasses")
|
||||
|
||||
@abstractmethod
|
||||
def get_support_image_formats(self) -> list[str]:
|
||||
"""
|
||||
获取支持的图片格式
|
||||
:return: 支持的图片格式列表
|
||||
def get_support_image_formats(self) -> List[str]:
|
||||
"""获取支持的图片格式。
|
||||
|
||||
Returns:
|
||||
List[str]: 支持的图片格式列表。
|
||||
"""
|
||||
raise NotImplementedError("'get_support_image_formats' method should be overridden in subclasses")
|
||||
|
||||
|
||||
class ClientRegistry:
|
||||
"""客户端注册表。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.client_registry: dict[str, type[BaseClient]] = {}
|
||||
"""初始化注册表并绑定配置重载回调。"""
|
||||
self.client_registry: Dict[str, Type[BaseClient]] = {}
|
||||
"""APIProvider.type -> BaseClient的映射表"""
|
||||
self.client_instance_cache: dict[str, BaseClient] = {}
|
||||
self.client_instance_cache: Dict[str, BaseClient] = {}
|
||||
"""APIProvider.name -> BaseClient的映射表"""
|
||||
config_manager.register_reload_callback(self.clear_client_instance_cache)
|
||||
|
||||
def register_client_class(self, client_type: str):
|
||||
"""
|
||||
注册API客户端类
|
||||
def register_client_class(self, client_type: str) -> Callable[[Type[BaseClient]], Type[BaseClient]]:
|
||||
"""注册 API 客户端类。
|
||||
|
||||
Args:
|
||||
client_class: API客户端类
|
||||
client_type: 客户端类型标识。
|
||||
|
||||
Returns:
|
||||
Callable[[Type[BaseClient]], Type[BaseClient]]: 装饰器函数。
|
||||
"""
|
||||
|
||||
def decorator(cls: type[BaseClient]) -> type[BaseClient]:
|
||||
def decorator(cls: Type[BaseClient]) -> Type[BaseClient]:
|
||||
if not issubclass(cls, BaseClient):
|
||||
raise TypeError(f"{cls.__name__} is not a subclass of BaseClient")
|
||||
self.client_registry[client_type] = cls
|
||||
@@ -165,14 +227,15 @@ class ClientRegistry:
|
||||
|
||||
return decorator
|
||||
|
||||
def get_client_class_instance(self, api_provider: APIProvider, force_new=False) -> BaseClient:
|
||||
"""
|
||||
获取注册的API客户端实例
|
||||
def get_client_class_instance(self, api_provider: APIProvider, force_new: bool = False) -> BaseClient:
|
||||
"""获取注册的 API 客户端实例。
|
||||
|
||||
Args:
|
||||
api_provider: APIProvider实例
|
||||
force_new: 是否强制创建新实例(用于解决事件循环问题)
|
||||
api_provider: APIProvider 实例。
|
||||
force_new: 是否强制创建新实例。
|
||||
|
||||
Returns:
|
||||
BaseClient: 注册的API客户端实例
|
||||
BaseClient: 注册的 API 客户端实例。
|
||||
"""
|
||||
from . import ensure_client_type_loaded
|
||||
|
||||
@@ -194,6 +257,7 @@ class ClientRegistry:
|
||||
return self.client_instance_cache[api_provider.name]
|
||||
|
||||
def clear_client_instance_cache(self) -> None:
|
||||
"""清空客户端实例缓存。"""
|
||||
self.client_instance_cache.clear()
|
||||
logger.info("检测到配置重载,已清空LLM客户端实例缓存")
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
140
src/llm_models/openai_compat.py
Normal file
140
src/llm_models/openai_compat.py
Normal file
@@ -0,0 +1,140 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Mapping
|
||||
|
||||
from src.config.model_configs import APIProvider, OpenAICompatibleAuthType
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class OpenAICompatibleClientConfig:
|
||||
"""OpenAI 兼容客户端的基础配置。"""
|
||||
|
||||
api_key: str
|
||||
base_url: str
|
||||
default_headers: dict[str, str] = field(default_factory=dict)
|
||||
default_query: dict[str, object] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class OpenAICompatibleRequestOverrides:
|
||||
"""单次请求级别的附加配置。"""
|
||||
|
||||
extra_headers: dict[str, str] = field(default_factory=dict)
|
||||
extra_query: dict[str, object] = field(default_factory=dict)
|
||||
extra_body: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
def normalize_openai_base_url(base_url: str) -> str:
|
||||
"""规范化 OpenAI 兼容接口的基础地址。
|
||||
|
||||
Args:
|
||||
base_url: 原始基础地址。
|
||||
|
||||
Returns:
|
||||
str: 去掉尾部斜杠后的地址。
|
||||
"""
|
||||
return base_url.rstrip("/")
|
||||
|
||||
|
||||
def _build_auth_header_value(prefix: str, api_key: str) -> str:
|
||||
"""构造鉴权请求头的值。
|
||||
|
||||
Args:
|
||||
prefix: 请求头前缀。
|
||||
api_key: 实际密钥。
|
||||
|
||||
Returns:
|
||||
str: 拼接完成的请求头值。
|
||||
"""
|
||||
normalized_prefix = prefix.strip()
|
||||
if not normalized_prefix:
|
||||
return api_key
|
||||
return f"{normalized_prefix} {api_key}"
|
||||
|
||||
|
||||
def build_openai_compatible_client_config(api_provider: APIProvider) -> OpenAICompatibleClientConfig:
|
||||
"""构建 OpenAI 兼容客户端配置。
|
||||
|
||||
Args:
|
||||
api_provider: API 提供商配置。
|
||||
|
||||
Returns:
|
||||
OpenAICompatibleClientConfig: 可直接用于初始化 SDK 客户端的配置。
|
||||
"""
|
||||
default_headers = dict(api_provider.default_headers)
|
||||
default_query: dict[str, object] = dict(api_provider.default_query)
|
||||
client_api_key = api_provider.api_key
|
||||
|
||||
if api_provider.auth_type == OpenAICompatibleAuthType.BEARER:
|
||||
if (
|
||||
api_provider.auth_header_name != "Authorization"
|
||||
or api_provider.auth_header_prefix.strip() != "Bearer"
|
||||
):
|
||||
client_api_key = ""
|
||||
default_headers[api_provider.auth_header_name] = _build_auth_header_value(
|
||||
prefix=api_provider.auth_header_prefix,
|
||||
api_key=api_provider.api_key,
|
||||
)
|
||||
elif api_provider.auth_type == OpenAICompatibleAuthType.HEADER:
|
||||
client_api_key = ""
|
||||
default_headers[api_provider.auth_header_name] = _build_auth_header_value(
|
||||
prefix=api_provider.auth_header_prefix,
|
||||
api_key=api_provider.api_key,
|
||||
)
|
||||
elif api_provider.auth_type == OpenAICompatibleAuthType.QUERY:
|
||||
client_api_key = ""
|
||||
default_query[api_provider.auth_query_name] = api_provider.api_key
|
||||
elif api_provider.auth_type == OpenAICompatibleAuthType.NONE:
|
||||
client_api_key = ""
|
||||
|
||||
return OpenAICompatibleClientConfig(
|
||||
api_key=client_api_key,
|
||||
base_url=normalize_openai_base_url(api_provider.base_url),
|
||||
default_headers=default_headers,
|
||||
default_query=default_query,
|
||||
)
|
||||
|
||||
|
||||
def _extract_mapping(value: Any) -> dict[str, Any]:
|
||||
"""将任意映射值规范化为普通字典。
|
||||
|
||||
Args:
|
||||
value: 原始输入值。
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: 规范化后的字典。非映射值时返回空字典。
|
||||
"""
|
||||
if isinstance(value, Mapping):
|
||||
return {str(key): item for key, item in value.items()}
|
||||
return {}
|
||||
|
||||
|
||||
def split_openai_request_overrides(
|
||||
extra_params: Mapping[str, Any] | None,
|
||||
*,
|
||||
reserved_body_keys: set[str] | None = None,
|
||||
) -> OpenAICompatibleRequestOverrides:
|
||||
"""拆分单次请求中的头、查询参数和请求体扩展字段。
|
||||
|
||||
Args:
|
||||
extra_params: 模型级别或请求级别的附加参数。
|
||||
reserved_body_keys: 由 SDK 原生参数承载、因此不应再进入 `extra_body` 的字段集合。
|
||||
|
||||
Returns:
|
||||
OpenAICompatibleRequestOverrides: 拆分后的请求覆盖配置。
|
||||
"""
|
||||
raw_params = dict(extra_params or {})
|
||||
extra_headers = _extract_mapping(raw_params.pop("headers", None))
|
||||
extra_query = _extract_mapping(raw_params.pop("query", None))
|
||||
extra_body = _extract_mapping(raw_params.pop("body", None))
|
||||
blocked_body_keys = reserved_body_keys or set()
|
||||
|
||||
for key, value in raw_params.items():
|
||||
if key in blocked_body_keys:
|
||||
continue
|
||||
extra_body[key] = value
|
||||
|
||||
return OpenAICompatibleRequestOverrides(
|
||||
extra_headers={key: str(value) for key, value in extra_headers.items()},
|
||||
extra_query=extra_query,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
@@ -1,133 +1,280 @@
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
from typing import List, Tuple
|
||||
|
||||
from .tool_option import ToolCall
|
||||
|
||||
|
||||
# 设计这系列类的目的是为未来可能的扩展做准备
|
||||
class RoleType(str, Enum):
|
||||
"""消息角色类型。"""
|
||||
|
||||
|
||||
class RoleType(Enum):
|
||||
System = "system"
|
||||
User = "user"
|
||||
Assistant = "assistant"
|
||||
Tool = "tool"
|
||||
|
||||
|
||||
SUPPORTED_IMAGE_FORMATS = ["jpg", "jpeg", "png", "webp", "gif"] # openai支持的图片格式
|
||||
SUPPORTED_IMAGE_FORMATS = ["jpg", "jpeg", "png", "webp", "gif"]
|
||||
"""默认支持的图片格式列表。"""
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class TextMessagePart:
|
||||
"""文本消息片段。"""
|
||||
|
||||
text: str
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""执行文本片段的基础校验。
|
||||
|
||||
Raises:
|
||||
ValueError: 当文本为空时抛出。
|
||||
"""
|
||||
if self.text == "":
|
||||
raise ValueError("文本消息片段不能为空字符串")
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ImageMessagePart:
|
||||
"""Base64 图片消息片段。"""
|
||||
|
||||
image_format: str
|
||||
image_base64: str
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""执行图片片段的基础校验。
|
||||
|
||||
Raises:
|
||||
ValueError: 当图片格式或 Base64 数据无效时抛出。
|
||||
"""
|
||||
if self.image_format.lower() not in SUPPORTED_IMAGE_FORMATS:
|
||||
raise ValueError("不受支持的图片格式")
|
||||
if not self.image_base64:
|
||||
raise ValueError("图片的 base64 编码不能为空")
|
||||
|
||||
@property
|
||||
def normalized_image_format(self) -> str:
|
||||
"""获取规范化后的图片格式。
|
||||
|
||||
Returns:
|
||||
str: 规范化后的图片格式。`jpg` 会被统一为 `jpeg`。
|
||||
"""
|
||||
image_format = self.image_format.lower()
|
||||
if image_format in {"jpg", "jpeg"}:
|
||||
return "jpeg"
|
||||
return image_format
|
||||
|
||||
|
||||
MessagePart = TextMessagePart | ImageMessagePart
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class Message:
|
||||
def __init__(
|
||||
self,
|
||||
role: RoleType,
|
||||
content: str | list[tuple[str, str] | str],
|
||||
tool_call_id: str | None = None,
|
||||
tool_calls: Optional[List[ToolCall]] = None,
|
||||
):
|
||||
"""统一消息模型。"""
|
||||
|
||||
role: RoleType
|
||||
parts: List[MessagePart] = field(default_factory=list)
|
||||
tool_call_id: str | None = None
|
||||
tool_calls: List[ToolCall] | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""执行消息对象的基础校验。
|
||||
|
||||
Raises:
|
||||
ValueError: 当消息内容或工具调用信息不完整时抛出。
|
||||
"""
|
||||
初始化消息对象
|
||||
(不应直接修改Message类,而应使用MessageBuilder类来构建对象)
|
||||
if not self.parts and not (self.role == RoleType.Assistant and self.tool_calls):
|
||||
raise ValueError("消息内容不能为空")
|
||||
if self.role == RoleType.Tool and not self.tool_call_id:
|
||||
raise ValueError("Tool 角色的工具调用 ID 不能为空")
|
||||
|
||||
@property
|
||||
def content(self) -> str | List[Tuple[str, str] | str]:
|
||||
"""获取兼容旧逻辑的内容视图。
|
||||
|
||||
Returns:
|
||||
str | List[Tuple[str, str] | str]: 当仅包含一个文本片段时返回字符串,
|
||||
否则返回混合列表,其中图片片段表示为 `(format, base64)` 元组。
|
||||
"""
|
||||
self.role: RoleType = role
|
||||
self.content: str | list[tuple[str, str] | str] = content
|
||||
self.tool_call_id: str | None = tool_call_id
|
||||
self.tool_calls: Optional[List[ToolCall]] = tool_calls
|
||||
if len(self.parts) == 1 and isinstance(self.parts[0], TextMessagePart):
|
||||
return self.parts[0].text
|
||||
content_items: List[Tuple[str, str] | str] = []
|
||||
for part in self.parts:
|
||||
if isinstance(part, TextMessagePart):
|
||||
content_items.append(part.text)
|
||||
else:
|
||||
content_items.append((part.image_format, part.image_base64))
|
||||
return content_items
|
||||
|
||||
def get_text_content(self) -> str:
|
||||
"""提取消息中的所有文本片段。
|
||||
|
||||
Returns:
|
||||
str: 以原始顺序拼接后的文本内容。
|
||||
"""
|
||||
return "".join(part.text for part in self.parts if isinstance(part, TextMessagePart))
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""生成便于调试的字符串表示。
|
||||
|
||||
Returns:
|
||||
str: 当前消息对象的可读摘要。
|
||||
"""
|
||||
return (
|
||||
f"Role: {self.role}, Content: {self.content}, "
|
||||
f"Role: {self.role}, Parts: {self.parts}, "
|
||||
f"Tool Call ID: {self.tool_call_id}, Tool Calls: {self.tool_calls}"
|
||||
)
|
||||
|
||||
|
||||
class MessageBuilder:
|
||||
def __init__(self):
|
||||
"""消息构建器。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化构建器。"""
|
||||
self.__role: RoleType = RoleType.User
|
||||
self.__content: list[tuple[str, str] | str] = []
|
||||
self.__parts: List[MessagePart] = []
|
||||
self.__tool_call_id: str | None = None
|
||||
self.__tool_calls: Optional[List[ToolCall]] = None
|
||||
self.__tool_calls: List[ToolCall] | None = None
|
||||
|
||||
def set_role(self, role: RoleType = RoleType.User) -> "MessageBuilder":
|
||||
"""
|
||||
设置角色(默认为User)
|
||||
:param role: 角色
|
||||
:return: MessageBuilder对象
|
||||
"""设置消息角色。
|
||||
|
||||
Args:
|
||||
role: 目标角色,默认为 `user`。
|
||||
|
||||
Returns:
|
||||
MessageBuilder: 当前构建器实例。
|
||||
"""
|
||||
self.__role = role
|
||||
return self
|
||||
|
||||
def add_text_part(self, text: str) -> "MessageBuilder":
|
||||
"""追加文本片段。
|
||||
|
||||
Args:
|
||||
text: 文本内容。
|
||||
|
||||
Returns:
|
||||
MessageBuilder: 当前构建器实例。
|
||||
"""
|
||||
self.__parts.append(TextMessagePart(text=text))
|
||||
return self
|
||||
|
||||
def add_text_content(self, text: str) -> "MessageBuilder":
|
||||
"""追加文本片段。
|
||||
|
||||
Args:
|
||||
text: 文本内容。
|
||||
|
||||
Returns:
|
||||
MessageBuilder: 当前构建器实例。
|
||||
"""
|
||||
添加文本内容
|
||||
:param text: 文本内容
|
||||
:return: MessageBuilder对象
|
||||
return self.add_text_part(text)
|
||||
|
||||
def add_image_base64_part(
|
||||
self,
|
||||
image_format: str,
|
||||
image_base64: str,
|
||||
support_formats: List[str] = SUPPORTED_IMAGE_FORMATS,
|
||||
) -> "MessageBuilder":
|
||||
"""追加 Base64 图片片段。
|
||||
|
||||
Args:
|
||||
image_format: 图片格式。
|
||||
image_base64: 图片的 Base64 编码。
|
||||
support_formats: 允许的图片格式列表。
|
||||
|
||||
Returns:
|
||||
MessageBuilder: 当前构建器实例。
|
||||
|
||||
Raises:
|
||||
ValueError: 当图片格式不被支持时抛出。
|
||||
"""
|
||||
self.__content.append(text)
|
||||
if image_format.lower() not in support_formats:
|
||||
raise ValueError("不受支持的图片格式")
|
||||
self.__parts.append(ImageMessagePart(image_format=image_format, image_base64=image_base64))
|
||||
return self
|
||||
|
||||
def add_image_content(
|
||||
self,
|
||||
image_format: str,
|
||||
image_base64: str,
|
||||
support_formats: list[str] = SUPPORTED_IMAGE_FORMATS, # 默认支持格式
|
||||
support_formats: List[str] = SUPPORTED_IMAGE_FORMATS,
|
||||
) -> "MessageBuilder":
|
||||
"""
|
||||
添加图片内容
|
||||
:param image_format: 图片格式
|
||||
:param image_base64: 图片的base64编码
|
||||
:return: MessageBuilder对象
|
||||
"""
|
||||
if image_format.lower() not in support_formats:
|
||||
raise ValueError("不受支持的图片格式")
|
||||
if not image_base64:
|
||||
raise ValueError("图片的base64编码不能为空")
|
||||
self.__content.append((image_format, image_base64))
|
||||
return self
|
||||
"""追加 Base64 图片片段。
|
||||
|
||||
def add_tool_call(self, tool_call_id: str) -> "MessageBuilder":
|
||||
Args:
|
||||
image_format: 图片格式。
|
||||
image_base64: 图片的 Base64 编码。
|
||||
support_formats: 允许的图片格式列表。
|
||||
|
||||
Returns:
|
||||
MessageBuilder: 当前构建器实例。
|
||||
"""
|
||||
添加工具调用指令(调用时请确保已设置为Tool角色)
|
||||
:param tool_call_id: 工具调用指令的id
|
||||
:return: MessageBuilder对象
|
||||
return self.add_image_base64_part(
|
||||
image_format=image_format,
|
||||
image_base64=image_base64,
|
||||
support_formats=support_formats,
|
||||
)
|
||||
|
||||
def set_tool_call_id(self, tool_call_id: str) -> "MessageBuilder":
|
||||
"""设置工具结果消息引用的工具调用 ID。
|
||||
|
||||
Args:
|
||||
tool_call_id: 工具调用 ID。
|
||||
|
||||
Returns:
|
||||
MessageBuilder: 当前构建器实例。
|
||||
|
||||
Raises:
|
||||
ValueError: 当当前角色不是 `tool` 或 ID 为空时抛出。
|
||||
"""
|
||||
if self.__role != RoleType.Tool:
|
||||
raise ValueError("仅当角色为Tool时才能添加工具调用ID")
|
||||
raise ValueError("仅当角色为 Tool 时才能设置工具调用 ID")
|
||||
if not tool_call_id:
|
||||
raise ValueError("工具调用ID不能为空")
|
||||
raise ValueError("工具调用 ID 不能为空")
|
||||
self.__tool_call_id = tool_call_id
|
||||
return self
|
||||
|
||||
def set_tool_calls(self, tool_calls: List[ToolCall]) -> "MessageBuilder":
|
||||
def add_tool_call(self, tool_call_id: str) -> "MessageBuilder":
|
||||
"""设置工具结果消息引用的工具调用 ID。
|
||||
|
||||
Args:
|
||||
tool_call_id: 工具调用 ID。
|
||||
|
||||
Returns:
|
||||
MessageBuilder: 当前构建器实例。
|
||||
"""
|
||||
设置助手消息的工具调用列表
|
||||
:param tool_calls: 工具调用列表
|
||||
:return: MessageBuilder对象
|
||||
return self.set_tool_call_id(tool_call_id)
|
||||
|
||||
def set_tool_calls(self, tool_calls: List[ToolCall]) -> "MessageBuilder":
|
||||
"""设置助手消息中的工具调用列表。
|
||||
|
||||
Args:
|
||||
tool_calls: 工具调用列表。
|
||||
|
||||
Returns:
|
||||
MessageBuilder: 当前构建器实例。
|
||||
|
||||
Raises:
|
||||
ValueError: 当当前角色不是 `assistant` 或列表为空时抛出。
|
||||
"""
|
||||
if self.__role != RoleType.Assistant:
|
||||
raise ValueError("仅当角色为Assistant时才能设置工具调用列表")
|
||||
raise ValueError("仅当角色为 Assistant 时才能设置工具调用列表")
|
||||
if not tool_calls:
|
||||
raise ValueError("工具调用列表不能为空")
|
||||
self.__tool_calls = tool_calls
|
||||
self.__tool_calls = list(tool_calls)
|
||||
return self
|
||||
|
||||
def build(self) -> Message:
|
||||
"""
|
||||
构建消息对象
|
||||
:return: Message对象
|
||||
"""
|
||||
if len(self.__content) == 0 and not (self.__role == RoleType.Assistant and self.__tool_calls):
|
||||
raise ValueError("内容不能为空")
|
||||
if self.__role == RoleType.Tool and self.__tool_call_id is None:
|
||||
raise ValueError("Tool角色的工具调用ID不能为空")
|
||||
"""构建消息对象。
|
||||
|
||||
Returns:
|
||||
Message: 构建完成的消息对象。
|
||||
"""
|
||||
return Message(
|
||||
role=self.__role,
|
||||
content=(
|
||||
self.__content[0]
|
||||
if (len(self.__content) == 1 and isinstance(self.__content[0], str))
|
||||
else self.__content
|
||||
),
|
||||
parts=list(self.__parts),
|
||||
tool_call_id=self.__tool_call_id,
|
||||
tool_calls=self.__tool_calls,
|
||||
tool_calls=list(self.__tool_calls) if self.__tool_calls else None,
|
||||
)
|
||||
|
||||
@@ -1,51 +1,40 @@
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from typing import Optional, Any
|
||||
from typing import Any, Dict, List, Mapping, Optional, Type, cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import TypedDict, Required
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class RespFormatType(Enum):
|
||||
TEXT = "text" # 文本
|
||||
JSON_OBJ = "json_object" # JSON
|
||||
JSON_SCHEMA = "json_schema" # JSON Schema
|
||||
"""响应格式类型。"""
|
||||
|
||||
TEXT = "text"
|
||||
JSON_OBJ = "json_object"
|
||||
JSON_SCHEMA = "json_schema"
|
||||
|
||||
|
||||
class JsonSchema(TypedDict, total=False):
|
||||
"""内部使用的 JSON Schema 包装结构。"""
|
||||
|
||||
name: Required[str]
|
||||
"""
|
||||
The name of the response format.
|
||||
|
||||
Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length
|
||||
of 64.
|
||||
"""
|
||||
|
||||
description: Optional[str]
|
||||
"""
|
||||
A description of what the response format is for, used by the model to determine
|
||||
how to respond in the format.
|
||||
"""
|
||||
|
||||
schema: dict[str, object]
|
||||
"""
|
||||
The schema for the response format, described as a JSON Schema object. Learn how
|
||||
to build JSON schemas [here](https://json-schema.org/).
|
||||
"""
|
||||
|
||||
schema: Dict[str, Any]
|
||||
strict: Optional[bool]
|
||||
"""
|
||||
Whether to enable strict schema adherence when generating the output. If set to
|
||||
true, the model will always follow the exact schema defined in the `schema`
|
||||
field. Only a subset of JSON Schema is supported when `strict` is `true`. To
|
||||
learn more, read the
|
||||
[Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs).
|
||||
"""
|
||||
|
||||
|
||||
def _json_schema_type_check(instance) -> str | None:
|
||||
def _json_schema_type_check(instance: Mapping[str, Any]) -> str | None:
|
||||
"""检查 JSON Schema 包装结构是否合法。
|
||||
|
||||
Args:
|
||||
instance: 待检查的 JSON Schema 包装字典。
|
||||
|
||||
Returns:
|
||||
str | None: 不合法时返回错误信息,合法时返回 `None`。
|
||||
"""
|
||||
if "name" not in instance:
|
||||
return "schema必须包含'name'字段"
|
||||
elif not isinstance(instance["name"], str) or instance["name"].strip() == "":
|
||||
if not isinstance(instance["name"], str) or instance["name"].strip() == "":
|
||||
return "schema的'name'字段必须是非空字符串"
|
||||
if "description" in instance and (
|
||||
not isinstance(instance["description"], str) or instance["description"].strip() == ""
|
||||
@@ -53,164 +42,198 @@ def _json_schema_type_check(instance) -> str | None:
|
||||
return "schema的'description'字段只能填入非空字符串"
|
||||
if "schema" not in instance:
|
||||
return "schema必须包含'schema'字段"
|
||||
elif not isinstance(instance["schema"], dict):
|
||||
if not isinstance(instance["schema"], dict):
|
||||
return "schema的'schema'字段必须是字典,详见https://json-schema.org/"
|
||||
if "strict" in instance and not isinstance(instance["strict"], bool):
|
||||
return "schema的'strict'字段只能填入布尔值"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _remove_title(schema: dict[str, Any] | list[Any]) -> dict[str, Any] | list[Any]:
|
||||
"""
|
||||
递归移除JSON Schema中的title字段
|
||||
def _remove_title(schema: Dict[str, Any] | List[Any]) -> Dict[str, Any] | List[Any]:
|
||||
"""递归移除 JSON Schema 中的 `title` 字段。
|
||||
|
||||
Args:
|
||||
schema: 待处理的 Schema 树。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any] | List[Any]: 处理后的 Schema 树。
|
||||
"""
|
||||
if isinstance(schema, list):
|
||||
# 如果当前Schema是列表,则对所有dict/list子元素递归调用
|
||||
for idx, item in enumerate(schema):
|
||||
for index, item in enumerate(schema):
|
||||
if isinstance(item, (dict, list)):
|
||||
schema[idx] = _remove_title(item)
|
||||
elif isinstance(schema, dict):
|
||||
# 是字典,移除title字段,并对所有dict/list子元素递归调用
|
||||
if "title" in schema:
|
||||
del schema["title"]
|
||||
for key, value in schema.items():
|
||||
if isinstance(value, (dict, list)):
|
||||
schema[key] = _remove_title(value)
|
||||
schema[index] = _remove_title(item)
|
||||
return schema
|
||||
|
||||
if "title" in schema:
|
||||
del schema["title"]
|
||||
for key, value in schema.items():
|
||||
if isinstance(value, (dict, list)):
|
||||
schema[key] = _remove_title(value)
|
||||
return schema
|
||||
|
||||
|
||||
def _link_definitions(schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
链接JSON Schema中的definitions字段
|
||||
def _link_definitions(schema: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""展开 Schema 中的本地 `$defs`/`$ref` 引用。
|
||||
|
||||
Args:
|
||||
schema: 待处理的根 Schema。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 展开后的 Schema。
|
||||
"""
|
||||
|
||||
def link_definitions_recursive(
|
||||
path: str, sub_schema: list[Any] | dict[str, Any], defs: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
递归链接JSON Schema中的definitions字段
|
||||
:param path: 当前路径
|
||||
:param sub_schema: 子Schema
|
||||
:param defs: Schema定义集
|
||||
:return:
|
||||
path: str,
|
||||
sub_schema: Dict[str, Any] | List[Any],
|
||||
definitions: Dict[str, Any],
|
||||
) -> Dict[str, Any] | List[Any]:
|
||||
"""递归展开局部定义。
|
||||
|
||||
Args:
|
||||
path: 当前递归路径。
|
||||
sub_schema: 当前子 Schema。
|
||||
definitions: 已收集的定义字典。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any] | List[Any]: 展开后的子 Schema。
|
||||
"""
|
||||
if isinstance(sub_schema, list):
|
||||
# 如果当前Schema是列表,则遍历每个元素
|
||||
for i in range(len(sub_schema)):
|
||||
if isinstance(sub_schema[i], dict):
|
||||
sub_schema[i] = link_definitions_recursive(f"{path}/{str(i)}", sub_schema[i], defs)
|
||||
else:
|
||||
# 否则为字典
|
||||
if "$defs" in sub_schema:
|
||||
# 如果当前Schema有$def字段,则将其添加到defs中
|
||||
key_prefix = f"{path}/$defs/"
|
||||
for key, value in sub_schema["$defs"].items():
|
||||
def_key = key_prefix + key
|
||||
if def_key not in defs:
|
||||
defs[def_key] = value
|
||||
del sub_schema["$defs"]
|
||||
if "$ref" in sub_schema:
|
||||
# 如果当前Schema有$ref字段,则将其替换为defs中的定义
|
||||
def_key = sub_schema["$ref"]
|
||||
if def_key in defs:
|
||||
sub_schema = defs[def_key]
|
||||
else:
|
||||
raise ValueError(f"Schema中引用的定义'{def_key}'不存在")
|
||||
# 遍历键值对
|
||||
for key, value in sub_schema.items():
|
||||
if isinstance(value, (dict, list)):
|
||||
# 如果当前值是字典或列表,则递归调用
|
||||
sub_schema[key] = link_definitions_recursive(f"{path}/{key}", value, defs)
|
||||
for index, item in enumerate(sub_schema):
|
||||
if isinstance(item, (dict, list)):
|
||||
sub_schema[index] = link_definitions_recursive(f"{path}/{index}", item, definitions)
|
||||
return sub_schema
|
||||
|
||||
if "$defs" in sub_schema:
|
||||
key_prefix = f"{path}/$defs/"
|
||||
defs_payload = cast(Dict[str, Any], sub_schema["$defs"])
|
||||
for key, value in defs_payload.items():
|
||||
definition_key = key_prefix + key
|
||||
if definition_key not in definitions:
|
||||
definitions[definition_key] = value
|
||||
del sub_schema["$defs"]
|
||||
|
||||
if "$ref" in sub_schema:
|
||||
definition_key = cast(str, sub_schema["$ref"])
|
||||
if definition_key in definitions:
|
||||
return definitions[definition_key]
|
||||
raise ValueError(f"Schema中引用的定义'{definition_key}'不存在")
|
||||
|
||||
for key, value in sub_schema.items():
|
||||
if isinstance(value, (dict, list)):
|
||||
sub_schema[key] = link_definitions_recursive(f"{path}/{key}", value, definitions)
|
||||
return sub_schema
|
||||
|
||||
return link_definitions_recursive("#", schema, {})
|
||||
return cast(Dict[str, Any], link_definitions_recursive("#", schema, {}))
|
||||
|
||||
|
||||
def _remove_defs(schema: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
递归移除JSON Schema中的$defs字段
|
||||
def _remove_defs(schema: Dict[str, Any] | List[Any]) -> Dict[str, Any] | List[Any]:
|
||||
"""递归移除 JSON Schema 中的 `$defs` 字段。
|
||||
|
||||
Args:
|
||||
schema: 待处理的 Schema 树。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any] | List[Any]: 处理后的 Schema 树。
|
||||
"""
|
||||
if isinstance(schema, list):
|
||||
# 如果当前Schema是列表,则对所有dict/list子元素递归调用
|
||||
for idx, item in enumerate(schema):
|
||||
for index, item in enumerate(schema):
|
||||
if isinstance(item, (dict, list)):
|
||||
schema[idx] = _remove_title(item)
|
||||
elif isinstance(schema, dict):
|
||||
# 是字典,移除title字段,并对所有dict/list子元素递归调用
|
||||
if "$defs" in schema:
|
||||
del schema["$defs"]
|
||||
for key, value in schema.items():
|
||||
if isinstance(value, (dict, list)):
|
||||
schema[key] = _remove_title(value)
|
||||
schema[index] = _remove_defs(item)
|
||||
return schema
|
||||
|
||||
if "$defs" in schema:
|
||||
del schema["$defs"]
|
||||
for key, value in schema.items():
|
||||
if isinstance(value, (dict, list)):
|
||||
schema[key] = _remove_defs(value)
|
||||
return schema
|
||||
|
||||
|
||||
class RespFormat:
|
||||
"""
|
||||
响应格式
|
||||
"""
|
||||
"""统一响应格式定义。"""
|
||||
|
||||
@staticmethod
|
||||
def _generate_schema_from_model(schema):
|
||||
json_schema = {
|
||||
"name": schema.__name__,
|
||||
"schema": _remove_defs(_link_definitions(_remove_title(schema.model_json_schema()))),
|
||||
def _generate_schema_from_model(schema_model: Type[BaseModel]) -> JsonSchema:
|
||||
"""从 Pydantic 模型生成内部 JSON Schema 包装结构。
|
||||
|
||||
Args:
|
||||
schema_model: Pydantic 模型类。
|
||||
|
||||
Returns:
|
||||
JsonSchema: 内部统一 JSON Schema 包装结构。
|
||||
"""
|
||||
schema_tree = deepcopy(schema_model.model_json_schema())
|
||||
json_schema: JsonSchema = {
|
||||
"name": schema_model.__name__,
|
||||
"schema": cast(
|
||||
Dict[str, Any],
|
||||
_remove_defs(_link_definitions(cast(Dict[str, Any], _remove_title(schema_tree)))),
|
||||
),
|
||||
"strict": False,
|
||||
}
|
||||
if schema.__doc__:
|
||||
json_schema["description"] = schema.__doc__
|
||||
if schema_model.__doc__:
|
||||
json_schema["description"] = schema_model.__doc__
|
||||
return json_schema
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
format_type: RespFormatType = RespFormatType.TEXT,
|
||||
schema: type | JsonSchema | None = None,
|
||||
):
|
||||
"""
|
||||
响应格式
|
||||
:param format_type: 响应格式类型(默认为文本)
|
||||
:param schema: 模板类或JsonSchema(仅当format_type为JSON Schema时有效)
|
||||
schema: Type[BaseModel] | JsonSchema | None = None,
|
||||
) -> None:
|
||||
"""初始化响应格式对象。
|
||||
|
||||
Args:
|
||||
format_type: 响应格式类型。
|
||||
schema: 模型类或 JSON Schema 包装结构,仅 `JSON_SCHEMA` 模式使用。
|
||||
"""
|
||||
self.format_type: RespFormatType = format_type
|
||||
self.schema_source: Type[BaseModel] | JsonSchema | None = schema
|
||||
self.schema: JsonSchema | None = None
|
||||
|
||||
if format_type == RespFormatType.JSON_SCHEMA:
|
||||
if schema is None:
|
||||
raise ValueError("当format_type为'JSON_SCHEMA'时,schema不能为空")
|
||||
if isinstance(schema, dict):
|
||||
if check_msg := _json_schema_type_check(schema):
|
||||
raise ValueError(f"schema格式不正确,{check_msg}")
|
||||
if format_type != RespFormatType.JSON_SCHEMA:
|
||||
return
|
||||
if schema is None:
|
||||
raise ValueError("当format_type为'JSON_SCHEMA'时,schema不能为空")
|
||||
if isinstance(schema, dict):
|
||||
if check_msg := _json_schema_type_check(schema):
|
||||
raise ValueError(f"schema格式不正确,{check_msg}")
|
||||
self.schema = cast(JsonSchema, deepcopy(schema))
|
||||
return
|
||||
if isinstance(schema, type) and issubclass(schema, BaseModel):
|
||||
try:
|
||||
self.schema = self._generate_schema_from_model(schema)
|
||||
except Exception as exc:
|
||||
raise ValueError(
|
||||
f"自动生成JSON Schema时发生异常,请检查模型类{schema.__name__}的定义,详细信息:\n"
|
||||
f"{schema.__name__}:\n"
|
||||
) from exc
|
||||
return
|
||||
raise ValueError("schema必须是BaseModel的子类或JsonSchema")
|
||||
|
||||
self.schema = schema
|
||||
elif issubclass(schema, BaseModel):
|
||||
try:
|
||||
json_schema = self._generate_schema_from_model(schema)
|
||||
def get_schema_object(self) -> Dict[str, Any] | None:
|
||||
"""获取内部包装中的对象级 JSON Schema。
|
||||
|
||||
self.schema = json_schema
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"自动生成JSON Schema时发生异常,请检查模型类{schema.__name__}的定义,详细信息:\n"
|
||||
f"{schema.__name__}:\n"
|
||||
) from e
|
||||
else:
|
||||
raise ValueError("schema必须是BaseModel的子类或JsonSchema")
|
||||
else:
|
||||
self.schema = None
|
||||
|
||||
def to_dict(self):
|
||||
Returns:
|
||||
Dict[str, Any] | None: 对象级 JSON Schema;不存在时返回 `None`。
|
||||
"""
|
||||
将响应格式转换为字典
|
||||
:return: 字典
|
||||
if self.schema is None:
|
||||
return None
|
||||
schema_payload = self.schema.get("schema")
|
||||
if isinstance(schema_payload, dict):
|
||||
return cast(Dict[str, Any], deepcopy(schema_payload))
|
||||
return None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""将响应格式转换为字典。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 序列化后的响应格式字典。
|
||||
"""
|
||||
if self.schema:
|
||||
return {
|
||||
"format_type": self.format_type.value,
|
||||
"schema": self.schema,
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"format_type": self.format_type.value,
|
||||
}
|
||||
return {
|
||||
"format_type": self.format_type.value,
|
||||
}
|
||||
|
||||
@@ -1,83 +1,368 @@
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Tuple, TypeAlias, cast
|
||||
|
||||
|
||||
class ToolParamType(Enum):
|
||||
class ToolParamType(str, Enum):
|
||||
"""工具参数类型。"""
|
||||
|
||||
STRING = "string"
|
||||
INTEGER = "integer"
|
||||
NUMBER = "number"
|
||||
FLOAT = "number"
|
||||
BOOLEAN = "boolean"
|
||||
ARRAY = "array"
|
||||
OBJECT = "object"
|
||||
|
||||
|
||||
LegacyToolParameterTuple = Tuple[str, ToolParamType, str, bool, List[str] | None]
|
||||
"""旧版工具参数元组格式。"""
|
||||
|
||||
|
||||
def normalize_tool_param_type(raw_value: ToolParamType | str | None) -> ToolParamType:
|
||||
"""将任意输入值规范化为内部工具参数类型。
|
||||
|
||||
Args:
|
||||
raw_value: 原始参数类型值。
|
||||
|
||||
Returns:
|
||||
ToolParamType: 规范化后的参数类型。未知值会回退为 `STRING`。
|
||||
"""
|
||||
工具调用参数类型
|
||||
if isinstance(raw_value, ToolParamType):
|
||||
return raw_value
|
||||
|
||||
normalized_value = str(raw_value or "").strip().lower()
|
||||
if normalized_value in {"integer", "int"}:
|
||||
return ToolParamType.INTEGER
|
||||
if normalized_value in {"number", "float"}:
|
||||
return ToolParamType.NUMBER
|
||||
if normalized_value in {"boolean", "bool"}:
|
||||
return ToolParamType.BOOLEAN
|
||||
if normalized_value == "array":
|
||||
return ToolParamType.ARRAY
|
||||
if normalized_value == "object":
|
||||
return ToolParamType.OBJECT
|
||||
return ToolParamType.STRING
|
||||
|
||||
|
||||
def _is_object_schema(schema: Dict[str, Any]) -> bool:
|
||||
"""判断输入字典是否已经是对象级 JSON Schema。
|
||||
|
||||
Args:
|
||||
schema: 待判断的字典。
|
||||
|
||||
Returns:
|
||||
bool: 为对象级 JSON Schema 时返回 `True`。
|
||||
"""
|
||||
|
||||
STRING = "string" # 字符串
|
||||
INTEGER = "integer" # 整型
|
||||
FLOAT = "float" # 浮点型
|
||||
BOOLEAN = "bool" # 布尔型
|
||||
return schema.get("type") == "object" or "properties" in schema or "required" in schema
|
||||
|
||||
|
||||
def _build_parameters_schema_from_property_map(property_map: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""将属性映射转换为对象级 JSON Schema。
|
||||
|
||||
Args:
|
||||
property_map: 仅包含属性定义的映射。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 对象级 JSON Schema。
|
||||
"""
|
||||
required_names: List[str] = []
|
||||
normalized_properties: Dict[str, Any] = {}
|
||||
for property_name, property_schema in property_map.items():
|
||||
if not isinstance(property_schema, dict):
|
||||
continue
|
||||
|
||||
property_schema_copy = deepcopy(property_schema)
|
||||
is_required = bool(property_schema_copy.pop("required", False))
|
||||
if is_required:
|
||||
required_names.append(str(property_name))
|
||||
normalized_properties[str(property_name)] = property_schema_copy
|
||||
|
||||
parameters_schema: Dict[str, Any] = {
|
||||
"type": "object",
|
||||
"properties": normalized_properties,
|
||||
}
|
||||
if required_names:
|
||||
parameters_schema["required"] = required_names
|
||||
return parameters_schema
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ToolParam:
|
||||
"""
|
||||
工具调用参数
|
||||
"""
|
||||
"""工具参数定义。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str
|
||||
param_type: ToolParamType
|
||||
description: str
|
||||
required: bool
|
||||
enum_values: List[Any] | None = None
|
||||
items_schema: Dict[str, Any] | None = None
|
||||
properties: Dict[str, Dict[str, Any]] | None = None
|
||||
required_properties: List[str] = field(default_factory=list)
|
||||
additional_properties: bool | Dict[str, Any] | None = None
|
||||
default: Any = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""执行参数定义的基础校验。
|
||||
|
||||
Raises:
|
||||
ValueError: 当参数名称或复杂类型定义不合法时抛出。
|
||||
"""
|
||||
if not self.name:
|
||||
raise ValueError("参数名称不能为空")
|
||||
if self.param_type == ToolParamType.ARRAY and self.items_schema is None:
|
||||
raise ValueError("数组参数必须提供 items_schema")
|
||||
if self.param_type == ToolParamType.OBJECT and self.properties is None:
|
||||
self.properties = {}
|
||||
|
||||
@classmethod
|
||||
def from_legacy_tuple(cls, parameter: LegacyToolParameterTuple) -> "ToolParam":
|
||||
"""从旧版五元组参数定义构建工具参数。
|
||||
|
||||
Args:
|
||||
parameter: 旧版参数元组。
|
||||
|
||||
Returns:
|
||||
ToolParam: 规范化后的工具参数对象。
|
||||
"""
|
||||
return cls(
|
||||
name=parameter[0],
|
||||
param_type=parameter[1],
|
||||
description=parameter[2],
|
||||
required=parameter[3],
|
||||
enum_values=parameter[4],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(
|
||||
cls,
|
||||
name: str,
|
||||
param_type: ToolParamType,
|
||||
description: str,
|
||||
required: bool,
|
||||
enum_values: list[str] | None = None,
|
||||
):
|
||||
parameter_schema: Dict[str, Any],
|
||||
*,
|
||||
required: bool = False,
|
||||
) -> "ToolParam":
|
||||
"""从属性级 JSON Schema 或结构化参数字典构建工具参数。
|
||||
|
||||
Args:
|
||||
name: 参数名称。
|
||||
parameter_schema: 参数对应的 Schema 或结构化定义。
|
||||
required: 参数是否必填。
|
||||
|
||||
Returns:
|
||||
ToolParam: 规范化后的工具参数对象。
|
||||
"""
|
||||
初始化工具调用参数
|
||||
(不应直接修改ToolParam类,而应使用ToolOptionBuilder类来构建对象)
|
||||
:param name: 参数名称
|
||||
:param param_type: 参数类型
|
||||
:param description: 参数描述
|
||||
:param required: 是否必填
|
||||
raw_required_properties = parameter_schema.get("required_properties")
|
||||
if raw_required_properties is None and isinstance(parameter_schema.get("required"), list):
|
||||
raw_required_properties = parameter_schema.get("required")
|
||||
return cls(
|
||||
name=name,
|
||||
param_type=normalize_tool_param_type(parameter_schema.get("param_type") or parameter_schema.get("type")),
|
||||
description=str(parameter_schema.get("description", "") or ""),
|
||||
required=required,
|
||||
enum_values=deepcopy(parameter_schema.get("enum_values") or parameter_schema.get("enum")),
|
||||
items_schema=deepcopy(parameter_schema.get("items_schema") or parameter_schema.get("items")),
|
||||
properties=deepcopy(parameter_schema.get("properties")),
|
||||
required_properties=list(raw_required_properties or []),
|
||||
additional_properties=deepcopy(
|
||||
parameter_schema["additional_properties"]
|
||||
if "additional_properties" in parameter_schema
|
||||
else parameter_schema.get("additionalProperties")
|
||||
),
|
||||
default=deepcopy(parameter_schema.get("default")),
|
||||
)
|
||||
|
||||
def to_json_schema(self) -> Dict[str, Any]:
|
||||
"""将参数定义转换为 JSON Schema。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 参数对应的 JSON Schema 片段。
|
||||
"""
|
||||
self.name: str = name
|
||||
self.param_type: ToolParamType = param_type
|
||||
self.description: str = description
|
||||
self.required: bool = required
|
||||
self.enum_values: list[str] | None = enum_values
|
||||
schema: Dict[str, Any] = {
|
||||
"type": self.param_type.value,
|
||||
"description": self.description,
|
||||
}
|
||||
if self.enum_values:
|
||||
schema["enum"] = list(self.enum_values)
|
||||
if self.default is not None:
|
||||
schema["default"] = deepcopy(self.default)
|
||||
if self.param_type == ToolParamType.ARRAY and self.items_schema is not None:
|
||||
schema["items"] = deepcopy(self.items_schema)
|
||||
if self.param_type == ToolParamType.OBJECT:
|
||||
schema["properties"] = deepcopy(self.properties or {})
|
||||
if self.required_properties:
|
||||
schema["required"] = list(self.required_properties)
|
||||
if self.additional_properties is not None:
|
||||
schema["additionalProperties"] = deepcopy(self.additional_properties)
|
||||
return schema
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ToolOption:
|
||||
"""
|
||||
工具调用项
|
||||
"""
|
||||
"""工具定义。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
params: list[ToolParam] | None = None,
|
||||
):
|
||||
name: str
|
||||
description: str
|
||||
params: List[ToolParam] | None = None
|
||||
parameters_schema_override: Dict[str, Any] | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""执行工具定义的基础校验。
|
||||
|
||||
Raises:
|
||||
ValueError: 当工具名称、描述或参数 Schema 不合法时抛出。
|
||||
"""
|
||||
初始化工具调用项
|
||||
(不应直接修改ToolOption类,而应使用ToolOptionBuilder类来构建对象)
|
||||
:param name: 工具名称
|
||||
:param description: 工具描述
|
||||
:param params: 工具参数列表
|
||||
if not self.name:
|
||||
raise ValueError("工具名称不能为空")
|
||||
if not self.description:
|
||||
raise ValueError("工具描述不能为空")
|
||||
if self.parameters_schema_override is not None:
|
||||
schema_type = self.parameters_schema_override.get("type")
|
||||
if schema_type != "object":
|
||||
raise ValueError("工具参数 Schema 必须是 object 类型")
|
||||
|
||||
@classmethod
|
||||
def from_definition(cls, definition: Dict[str, Any]) -> "ToolOption":
|
||||
"""从任意支持的工具定义字典构建内部工具对象。
|
||||
|
||||
支持以下输入形状:
|
||||
- `{"name", "description", "parameters_schema"}`
|
||||
- `{"name", "description", "parameters"}`
|
||||
- OpenAI function tool:`{"type": "function", "function": {...}}`
|
||||
- 仅属性映射的对象参数定义:`{"query": {"type": "string"}}`
|
||||
|
||||
Args:
|
||||
definition: 原始工具定义字典。
|
||||
|
||||
Returns:
|
||||
ToolOption: 规范化后的工具定义对象。
|
||||
|
||||
Raises:
|
||||
ValueError: 当工具定义缺少必要字段时抛出。
|
||||
"""
|
||||
self.name: str = name
|
||||
self.description: str = description
|
||||
self.params: list[ToolParam] | None = params
|
||||
if definition.get("type") == "function" and isinstance(definition.get("function"), dict):
|
||||
function_definition = cast(Dict[str, Any], definition["function"])
|
||||
return cls.from_definition(
|
||||
{
|
||||
"name": function_definition.get("name", ""),
|
||||
"description": function_definition.get("description", ""),
|
||||
"parameters_schema": function_definition.get("parameters"),
|
||||
}
|
||||
)
|
||||
|
||||
name = str(definition.get("name", "") or "").strip()
|
||||
description = str(definition.get("description", "") or "").strip()
|
||||
if not name:
|
||||
raise ValueError("工具定义缺少 name")
|
||||
if not description:
|
||||
description = f"工具 {name}"
|
||||
|
||||
parameters_schema = definition.get("parameters_schema")
|
||||
if isinstance(parameters_schema, dict):
|
||||
normalized_schema = deepcopy(parameters_schema)
|
||||
if not _is_object_schema(normalized_schema):
|
||||
normalized_schema = _build_parameters_schema_from_property_map(normalized_schema)
|
||||
return cls(
|
||||
name=name,
|
||||
description=description,
|
||||
params=None,
|
||||
parameters_schema_override=normalized_schema,
|
||||
)
|
||||
|
||||
raw_parameters = definition.get("parameters")
|
||||
if isinstance(raw_parameters, dict):
|
||||
normalized_schema = deepcopy(raw_parameters)
|
||||
if not _is_object_schema(normalized_schema):
|
||||
normalized_schema = _build_parameters_schema_from_property_map(normalized_schema)
|
||||
return cls(
|
||||
name=name,
|
||||
description=description,
|
||||
params=None,
|
||||
parameters_schema_override=normalized_schema,
|
||||
)
|
||||
|
||||
if isinstance(raw_parameters, list):
|
||||
params: List[ToolParam] = []
|
||||
for raw_parameter in raw_parameters:
|
||||
if isinstance(raw_parameter, tuple) and len(raw_parameter) == 5:
|
||||
params.append(ToolParam.from_legacy_tuple(raw_parameter))
|
||||
continue
|
||||
if isinstance(raw_parameter, dict):
|
||||
parameter_name = str(raw_parameter.get("name", "") or "").strip()
|
||||
if not parameter_name:
|
||||
continue
|
||||
params.append(
|
||||
ToolParam.from_dict(
|
||||
parameter_name,
|
||||
raw_parameter,
|
||||
required=bool(raw_parameter.get("required", False)),
|
||||
)
|
||||
)
|
||||
return cls(
|
||||
name=name,
|
||||
description=description,
|
||||
params=params or None,
|
||||
parameters_schema_override=None,
|
||||
)
|
||||
|
||||
return cls(name=name, description=description, params=None, parameters_schema_override=None)
|
||||
|
||||
@property
|
||||
def parameters_schema(self) -> Dict[str, Any] | None:
|
||||
"""获取工具参数的对象级 JSON Schema。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any] | None: 工具参数 Schema。无参数工具时返回 `None`。
|
||||
"""
|
||||
if self.parameters_schema_override is not None:
|
||||
return deepcopy(self.parameters_schema_override)
|
||||
if not self.params:
|
||||
return None
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {param.name: param.to_json_schema() for param in self.params},
|
||||
"required": [param.name for param in self.params if param.required],
|
||||
}
|
||||
|
||||
def to_openai_function_schema(self) -> Dict[str, Any]:
|
||||
"""转换为 OpenAI function calling 结构。
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: OpenAI 兼容的工具定义。
|
||||
"""
|
||||
function_schema: Dict[str, Any] = {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
}
|
||||
if self.parameters_schema is not None:
|
||||
function_schema["parameters"] = self.parameters_schema
|
||||
return {
|
||||
"type": "function",
|
||||
"function": function_schema,
|
||||
}
|
||||
|
||||
|
||||
class ToolOptionBuilder:
|
||||
"""
|
||||
工具调用项构建器
|
||||
"""
|
||||
"""工具定义构建器。"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
"""初始化构建器。"""
|
||||
self.__name: str = ""
|
||||
self.__description: str = ""
|
||||
self.__params: list[ToolParam] = []
|
||||
self.__params: List[ToolParam] = []
|
||||
self.__parameters_schema_override: Dict[str, Any] | None = None
|
||||
|
||||
def set_name(self, name: str) -> "ToolOptionBuilder":
|
||||
"""
|
||||
设置工具名称
|
||||
:param name: 工具名称
|
||||
:return: ToolBuilder实例
|
||||
"""设置工具名称。
|
||||
|
||||
Args:
|
||||
name: 工具名称。
|
||||
|
||||
Returns:
|
||||
ToolOptionBuilder: 当前构建器实例。
|
||||
|
||||
Raises:
|
||||
ValueError: 当名称为空时抛出。
|
||||
"""
|
||||
if not name:
|
||||
raise ValueError("工具名称不能为空")
|
||||
@@ -85,35 +370,76 @@ class ToolOptionBuilder:
|
||||
return self
|
||||
|
||||
def set_description(self, description: str) -> "ToolOptionBuilder":
|
||||
"""
|
||||
设置工具描述
|
||||
:param description: 工具描述
|
||||
:return: ToolBuilder实例
|
||||
"""设置工具描述。
|
||||
|
||||
Args:
|
||||
description: 工具描述。
|
||||
|
||||
Returns:
|
||||
ToolOptionBuilder: 当前构建器实例。
|
||||
|
||||
Raises:
|
||||
ValueError: 当描述为空时抛出。
|
||||
"""
|
||||
if not description:
|
||||
raise ValueError("工具描述不能为空")
|
||||
self.__description = description
|
||||
return self
|
||||
|
||||
def set_parameters_schema(self, schema: Dict[str, Any]) -> "ToolOptionBuilder":
|
||||
"""直接设置完整的参数对象 Schema。
|
||||
|
||||
Args:
|
||||
schema: 完整的对象级 JSON Schema。
|
||||
|
||||
Returns:
|
||||
ToolOptionBuilder: 当前构建器实例。
|
||||
|
||||
Raises:
|
||||
ValueError: 当 schema 不是 object 类型时抛出。
|
||||
"""
|
||||
if schema.get("type") != "object":
|
||||
raise ValueError("工具参数 Schema 必须是 object 类型")
|
||||
self.__parameters_schema_override = deepcopy(schema)
|
||||
self.__params.clear()
|
||||
return self
|
||||
|
||||
def add_param(
|
||||
self,
|
||||
name: str,
|
||||
param_type: ToolParamType,
|
||||
description: str,
|
||||
required: bool = False,
|
||||
enum_values: list[str] | None = None,
|
||||
enum_values: List[Any] | None = None,
|
||||
*,
|
||||
items_schema: Dict[str, Any] | None = None,
|
||||
properties: Dict[str, Dict[str, Any]] | None = None,
|
||||
required_properties: List[str] | None = None,
|
||||
additional_properties: bool | Dict[str, Any] | None = None,
|
||||
default: Any = None,
|
||||
) -> "ToolOptionBuilder":
|
||||
"""
|
||||
添加工具参数
|
||||
:param name: 参数名称
|
||||
:param param_type: 参数类型
|
||||
:param description: 参数描述
|
||||
:param required: 是否必填(默认为False)
|
||||
:return: ToolBuilder实例
|
||||
"""
|
||||
if not name or not description:
|
||||
raise ValueError("参数名称/描述不能为空")
|
||||
"""添加一个参数定义。
|
||||
|
||||
Args:
|
||||
name: 参数名称。
|
||||
param_type: 参数类型。
|
||||
description: 参数描述。
|
||||
required: 参数是否必填。
|
||||
enum_values: 可选的枚举值列表。
|
||||
items_schema: 数组参数的元素 Schema。
|
||||
properties: 对象参数的属性定义。
|
||||
required_properties: 对象参数内部的必填字段。
|
||||
additional_properties: 对象参数是否允许额外字段。
|
||||
default: 参数默认值。
|
||||
|
||||
Returns:
|
||||
ToolOptionBuilder: 当前构建器实例。
|
||||
|
||||
Raises:
|
||||
ValueError: 当构建器已经设置完整 Schema 时抛出。
|
||||
"""
|
||||
if self.__parameters_schema_override is not None:
|
||||
raise ValueError("已设置完整参数 Schema,不能再逐项添加参数")
|
||||
self.__params.append(
|
||||
ToolParam(
|
||||
name=name,
|
||||
@@ -121,43 +447,83 @@ class ToolOptionBuilder:
|
||||
description=description,
|
||||
required=required,
|
||||
enum_values=enum_values,
|
||||
items_schema=deepcopy(items_schema),
|
||||
properties=deepcopy(properties),
|
||||
required_properties=list(required_properties or []),
|
||||
additional_properties=deepcopy(additional_properties),
|
||||
default=deepcopy(default),
|
||||
)
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def build(self):
|
||||
"""
|
||||
构建工具调用项
|
||||
:return: 工具调用项
|
||||
"""
|
||||
if self.__name == "" or self.__description == "":
|
||||
raise ValueError("工具名称/描述不能为空")
|
||||
def build(self) -> ToolOption:
|
||||
"""构建工具定义。
|
||||
|
||||
Returns:
|
||||
ToolOption: 构建完成的工具定义。
|
||||
|
||||
Raises:
|
||||
ValueError: 当工具名称或描述缺失时抛出。
|
||||
"""
|
||||
if not self.__name or not self.__description:
|
||||
raise ValueError("工具名称和描述不能为空")
|
||||
return ToolOption(
|
||||
name=self.__name,
|
||||
description=self.__description,
|
||||
params=None if len(self.__params) == 0 else self.__params,
|
||||
params=None if not self.__params else list(self.__params),
|
||||
parameters_schema_override=deepcopy(self.__parameters_schema_override),
|
||||
)
|
||||
|
||||
|
||||
class ToolCall:
|
||||
"""
|
||||
来自模型反馈的工具调用
|
||||
"""
|
||||
ToolDefinitionInput: TypeAlias = ToolOption | Dict[str, Any]
|
||||
"""统一的工具定义输入类型。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
call_id: str,
|
||||
func_name: str,
|
||||
args: dict | None = None,
|
||||
):
|
||||
|
||||
def normalize_tool_option(tool_definition: ToolDefinitionInput) -> ToolOption:
|
||||
"""将任意支持的工具输入规范化为内部 `ToolOption`。
|
||||
|
||||
Args:
|
||||
tool_definition: 原始工具定义输入。
|
||||
|
||||
Returns:
|
||||
ToolOption: 规范化后的工具定义对象。
|
||||
"""
|
||||
if isinstance(tool_definition, ToolOption):
|
||||
return tool_definition
|
||||
return ToolOption.from_definition(tool_definition)
|
||||
|
||||
|
||||
def normalize_tool_options(
|
||||
tool_definitions: List[ToolDefinitionInput] | None,
|
||||
) -> List[ToolOption] | None:
|
||||
"""批量规范化工具定义列表。
|
||||
|
||||
Args:
|
||||
tool_definitions: 原始工具定义列表。
|
||||
|
||||
Returns:
|
||||
List[ToolOption] | None: 规范化后的工具列表;输入为空时返回 `None`。
|
||||
"""
|
||||
if not tool_definitions:
|
||||
return None
|
||||
return [normalize_tool_option(tool_definition) for tool_definition in tool_definitions]
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class ToolCall:
|
||||
"""来自模型输出的工具调用。"""
|
||||
|
||||
call_id: str
|
||||
func_name: str
|
||||
args: Dict[str, Any] | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""执行工具调用的基础校验。
|
||||
|
||||
Raises:
|
||||
ValueError: 当工具调用标识或函数名缺失时抛出。
|
||||
"""
|
||||
初始化工具调用
|
||||
:param call_id: 工具调用ID
|
||||
:param func_name: 要调用的函数名称
|
||||
:param args: 工具调用参数
|
||||
"""
|
||||
self.call_id: str = call_id
|
||||
self.func_name: str = func_name
|
||||
self.args: dict | None = args
|
||||
if not self.call_id:
|
||||
raise ValueError("工具调用 ID 不能为空")
|
||||
if not self.func_name:
|
||||
raise ValueError("工具函数名称不能为空")
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user