feat: Enhance OpenAI compatibility and introduce unified LLM service data models

- Refactored model fetching logic to support various authentication methods for OpenAI-compatible APIs.
- Introduced new data models for LLM service requests and responses to standardize interactions across layers.
- Added an adapter base class for unified request execution across different providers.
- Implemented utility functions for building OpenAI-compatible client configurations and request overrides.
This commit is contained in:
DrSmoothl
2026-03-26 16:15:42 +08:00
parent 6e7daae55d
commit 777d4cb0d2
48 changed files with 5443 additions and 2945 deletions

View File

@@ -0,0 +1,259 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Coroutine, Generic, Tuple, TypeVar, cast
import asyncio
from src.config.model_configs import ModelInfo
from .base_client import (
APIResponse,
AudioTranscriptionRequest,
BaseClient,
EmbeddingRequest,
ResponseRequest,
UsageRecord,
UsageTuple,
)
# Type variable for the raw response object a provider produces in streaming mode.
RawStreamT = TypeVar("RawStreamT")
"""流式原始响应类型变量。"""
# Type variable for the raw response object a provider produces in non-streaming mode.
RawResponseT = TypeVar("RawResponseT")
"""非流式原始响应类型变量。"""
# Type variable for the result value of an awaited asyncio task.
TaskResultT = TypeVar("TaskResultT")
"""异步任务返回值类型变量。"""
# Provider-specific streaming handler: called with the raw stream and an optional
# interrupt event, and returns a coroutine that resolves to the parsed APIResponse
# plus an optional usage tuple.
ProviderStreamResponseHandler = Callable[
    [RawStreamT, asyncio.Event | None],
    Coroutine[Any, Any, Tuple[APIResponse, UsageTuple | None]],
]
"""Provider 专用流式响应处理函数类型。"""
# Provider-specific non-streaming parser: a plain (synchronous) callable that maps
# the raw response to the parsed APIResponse plus an optional usage tuple.
ProviderResponseParser = Callable[[RawResponseT], Tuple[APIResponse, UsageTuple | None]]
"""Provider 专用非流式响应解析函数类型。"""
async def await_task_with_interrupt(
    task: asyncio.Task[TaskResultT],
    interrupt_flag: asyncio.Event | None,
    *,
    interval_seconds: float = 0.1,
) -> TaskResultT:
    """Wait for ``task`` to finish while honouring an external interrupt flag.

    The flag is polled via ``is_set()`` at most every ``interval_seconds``.
    Between polls the coroutine waits on the task itself, so it resumes as soon
    as the task completes instead of always sleeping for a full interval.

    Args:
        task: The asyncio task to wait for.
        interrupt_flag: Optional external interrupt marker. ``None`` disables
            interruption and the call simply waits for the task.
        interval_seconds: Maximum time, in seconds, between two checks of
            ``interrupt_flag``.

    Returns:
        TaskResultT: The result of ``task``.

    Raises:
        ReqAbortException: If ``interrupt_flag`` is set while the task is still
            running. The task is asked to cancel before the exception is raised;
            its cancellation is not awaited here.
    """
    while not task.done():
        if interrupt_flag is not None and interrupt_flag.is_set():
            task.cancel()
            # Imported only on the abort path, where it is actually needed.
            # NOTE(review): the original imported this inside the function too,
            # presumably to avoid a circular import - confirm before hoisting it.
            from src.llm_models.exceptions import ReqAbortException

            raise ReqAbortException("请求被外部信号中断")
        # asyncio.wait returns when the task completes or the timeout elapses,
        # whichever comes first, and never cancels the task on timeout.
        await asyncio.wait({task}, timeout=interval_seconds)
    # The task is done: awaiting it returns its result or re-raises its exception.
    return await task
class AdapterClient(BaseClient, ABC, Generic[RawStreamT, RawResponseT]):
    """Provider adapter base class offering a shared request-execution skeleton.

    The public ``get_*`` methods delegate the provider-specific work to the
    abstract ``_execute_*`` hooks and then attach usage information to the
    returned response. Subclasses implement the hooks and the two default
    handler/parser factories.
    """

    async def get_response(self, request: ResponseRequest) -> APIResponse:
        """Fetch a chat / multimodal response.

        Args:
            request: The unified response request.

        Returns:
            APIResponse: The parsed unified response, with usage attached when
            the provider reported it.
        """
        handler = self._resolve_stream_response_handler(request)
        parser = self._resolve_response_parser(request)
        response, usage = await self._execute_response_request(request, handler, parser)
        return self._attach_usage_record(response, request.model_info, usage)

    async def get_embedding(self, request: EmbeddingRequest) -> APIResponse:
        """Fetch a text embedding.

        Args:
            request: The unified embedding request.

        Returns:
            APIResponse: The parsed unified embedding response.
        """
        response, usage = await self._execute_embedding_request(request)
        return self._attach_usage_record(response, request.model_info, usage)

    async def get_audio_transcriptions(self, request: AudioTranscriptionRequest) -> APIResponse:
        """Fetch an audio transcription.

        Args:
            request: The unified audio transcription request.

        Returns:
            APIResponse: The parsed unified transcription response.
        """
        response, usage = await self._execute_audio_transcription_request(request)
        return self._attach_usage_record(response, request.model_info, usage)

    def _resolve_stream_response_handler(
        self,
        request: ResponseRequest,
    ) -> ProviderStreamResponseHandler[RawStreamT]:
        """Pick the streaming handler that will actually be used.

        Args:
            request: The unified response request.

        Returns:
            ProviderStreamResponseHandler[RawStreamT]: The handler supplied on
            the request if there is one, otherwise the provider default.
        """
        custom_handler = request.stream_response_handler
        if custom_handler is None:
            return self._build_default_stream_response_handler(request)
        # The request-level alias is typed with ``Any``; narrow it for this provider.
        return cast(ProviderStreamResponseHandler[RawStreamT], custom_handler)

    def _resolve_response_parser(
        self,
        request: ResponseRequest,
    ) -> ProviderResponseParser[RawResponseT]:
        """Pick the non-streaming parser that will actually be used.

        Despite the ``async_response_parser`` field name, the parser type
        (``ProviderResponseParser``) is a plain callable returning a tuple.

        Args:
            request: The unified response request.

        Returns:
            ProviderResponseParser[RawResponseT]: The parser supplied on the
            request if there is one, otherwise the provider default.
        """
        custom_parser = request.async_response_parser
        if custom_parser is None:
            return self._build_default_response_parser(request)
        return cast(ProviderResponseParser[RawResponseT], custom_parser)

    @staticmethod
    def _build_usage_record(model_info: ModelInfo, usage_record: UsageTuple) -> UsageRecord:
        """Build a ``UsageRecord`` from a unified usage tuple.

        Args:
            model_info: Model information; supplies the model and provider names.
            usage_record: Usage tuple indexed as prompt, completion, total tokens.

        Returns:
            UsageRecord: A usage record that can be attached to an ``APIResponse``.
        """
        record_fields = {
            "model_name": model_info.name,
            "provider_name": model_info.api_provider,
            "prompt_tokens": usage_record[0],
            "completion_tokens": usage_record[1],
            "total_tokens": usage_record[2],
        }
        return UsageRecord(**record_fields)

    def _attach_usage_record(
        self,
        response: APIResponse,
        model_info: ModelInfo,
        usage_record: UsageTuple | None,
    ) -> APIResponse:
        """Attach unified usage information to a response, in place.

        Args:
            response: The already-parsed unified response.
            model_info: Model information.
            usage_record: Optional usage tuple; ``None`` leaves the response as is.

        Returns:
            APIResponse: The same response object that was passed in.
        """
        if usage_record is None:
            return response
        response.usage = self._build_usage_record(model_info, usage_record)
        return response

    @abstractmethod
    def _build_default_stream_response_handler(
        self,
        request: ResponseRequest,
    ) -> ProviderStreamResponseHandler[RawStreamT]:
        """Build the provider's default streaming handler.

        Args:
            request: The unified response request.

        Returns:
            ProviderStreamResponseHandler[RawStreamT]: The default streaming handler.
        """
        raise NotImplementedError

    @abstractmethod
    def _build_default_response_parser(
        self,
        request: ResponseRequest,
    ) -> ProviderResponseParser[RawResponseT]:
        """Build the provider's default non-streaming parser.

        Args:
            request: The unified response request.

        Returns:
            ProviderResponseParser[RawResponseT]: The default non-streaming parser.
        """
        raise NotImplementedError

    @abstractmethod
    async def _execute_response_request(
        self,
        request: ResponseRequest,
        stream_response_handler: ProviderStreamResponseHandler[RawStreamT],
        response_parser: ProviderResponseParser[RawResponseT],
    ) -> Tuple[APIResponse, UsageTuple | None]:
        """Run the provider's text / multimodal response request.

        Args:
            request: The unified response request.
            stream_response_handler: The resolved streaming handler.
            response_parser: The resolved non-streaming parser.

        Returns:
            Tuple[APIResponse, UsageTuple | None]: The unified response and
            optional usage information.
        """
        raise NotImplementedError

    @abstractmethod
    async def _execute_embedding_request(
        self,
        request: EmbeddingRequest,
    ) -> Tuple[APIResponse, UsageTuple | None]:
        """Run the provider's embedding request.

        Args:
            request: The unified embedding request.

        Returns:
            Tuple[APIResponse, UsageTuple | None]: The unified response and
            optional usage information.
        """
        raise NotImplementedError

    @abstractmethod
    async def _execute_audio_transcription_request(
        self,
        request: AudioTranscriptionRequest,
    ) -> Tuple[APIResponse, UsageTuple | None]:
        """Run the provider's audio transcription request.

        Args:
            request: The unified audio transcription request.

        Returns:
            Tuple[APIResponse, UsageTuple | None]: The unified response and
            optional usage information.
        """
        raise NotImplementedError

View File

@@ -1,14 +1,15 @@
import asyncio
from dataclasses import dataclass
from abc import ABC, abstractmethod
from typing import Callable, Any, Optional
from dataclasses import dataclass, field
from typing import Any, Callable, Coroutine, Dict, List, Tuple, Type
import asyncio
from src.common.logger import get_logger
from src.config.config import config_manager
from src.config.model_configs import ModelInfo, APIProvider
from ..payload_content.message import Message
from ..payload_content.resp_format import RespFormat
from ..payload_content.tool_option import ToolOption, ToolCall
from src.config.model_configs import APIProvider, ModelInfo
from src.llm_models.payload_content.message import Message
from src.llm_models.payload_content.resp_format import RespFormat
from src.llm_models.payload_content.tool_option import ToolCall, ToolOption
logger = get_logger("model_client_registry")
@@ -47,10 +48,10 @@ class APIResponse:
reasoning_content: str | None = None
"""推理内容"""
tool_calls: list[ToolCall] | None = None
tool_calls: List[ToolCall] | None = None
"""工具调用 [(工具名称, 工具参数), ...]"""
embedding: list[float] | None = None
embedding: List[float] | None = None
"""嵌入向量"""
usage: UsageRecord | None = None
@@ -60,6 +61,82 @@ class APIResponse:
"""响应原始数据"""
# Unified usage triple, ordered (prompt_tokens, completion_tokens, total_tokens).
UsageTuple = Tuple[int, int, int]
"""统一的使用量三元组类型,顺序为 `(prompt_tokens, completion_tokens, total_tokens)`。"""
# Unified streaming handler: called with the raw stream (untyped here) and an
# optional interrupt event; returns a coroutine resolving to the parsed response
# plus an optional usage triple.
StreamResponseHandler = Callable[
    [Any, asyncio.Event | None],
    Coroutine[Any, Any, Tuple["APIResponse", UsageTuple | None]],
]
"""统一的流式响应处理函数类型。"""
# Unified non-streaming parser: a plain (synchronous) callable mapping the raw
# response to the parsed response plus an optional usage triple.
ResponseParser = Callable[[Any], Tuple["APIResponse", UsageTuple | None]]
"""统一的非流式响应解析函数类型。"""
@dataclass(slots=True)
class ResponseRequest:
"""统一的文本/多模态响应请求。"""
model_info: ModelInfo
message_list: List[Message]
tool_options: List[ToolOption] | None = None
max_tokens: int | None = None
temperature: float | None = None
response_format: RespFormat | None = None
stream_response_handler: StreamResponseHandler | None = None
async_response_parser: ResponseParser | None = None
interrupt_flag: asyncio.Event | None = None
extra_params: Dict[str, Any] = field(default_factory=dict)
def copy_with(self, **changes: Any) -> "ResponseRequest":
"""基于当前请求创建一个带局部变更的新请求。
Args:
**changes: 需要覆盖的字段值。
Returns:
ResponseRequest: 复制后的请求对象。
"""
payload = {
"model_info": self.model_info,
"message_list": list(self.message_list),
"tool_options": None if self.tool_options is None else list(self.tool_options),
"max_tokens": self.max_tokens,
"temperature": self.temperature,
"response_format": self.response_format,
"stream_response_handler": self.stream_response_handler,
"async_response_parser": self.async_response_parser,
"interrupt_flag": self.interrupt_flag,
"extra_params": dict(self.extra_params),
}
payload.update(changes)
return ResponseRequest(**payload)
@dataclass(slots=True)
class EmbeddingRequest:
    """Unified embedding request."""

    # Model the request is addressed to.
    model_info: ModelInfo
    # Text to embed.
    embedding_input: str
    # Additional request parameters; a fresh empty dict per instance by default.
    extra_params: Dict[str, Any] = field(default_factory=dict)
@dataclass(slots=True)
class AudioTranscriptionRequest:
    """Unified audio transcription request."""

    # Model the request is addressed to.
    model_info: ModelInfo
    # Base64-encoded audio data.
    audio_base64: str
    # Optional token limit; None leaves it unset on the request.
    max_tokens: int | None = None
    # Additional request parameters; a fresh empty dict per instance by default.
    extra_params: Dict[str, Any] = field(default_factory=dict)
# Union of the three unified request dataclasses defined above.
ClientRequest = ResponseRequest | EmbeddingRequest | AudioTranscriptionRequest
"""统一客户端请求类型。"""
class BaseClient(ABC):
"""
基础客户端
@@ -67,97 +144,82 @@ class BaseClient(ABC):
api_provider: APIProvider
def __init__(self, api_provider: APIProvider):
def __init__(self, api_provider: APIProvider) -> None:
"""初始化基础客户端。
Args:
api_provider: API 提供商配置。
"""
self.api_provider = api_provider
@abstractmethod
async def get_response(
self,
model_info: ModelInfo,
message_list: list[Message],
tool_options: list[ToolOption] | None = None,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
response_format: RespFormat | None = None,
stream_response_handler: Optional[
Callable[[Any, asyncio.Event | None], tuple[APIResponse, tuple[int, int, int]]]
] = None,
async_response_parser: Callable[[Any], tuple[APIResponse, tuple[int, int, int]]] | None = None,
interrupt_flag: asyncio.Event | None = None,
extra_params: dict[str, Any] | None = None,
) -> APIResponse:
"""
获取对话响应
:param model_info: 模型信息
:param message_list: 对话体
:param tool_options: 工具选项可选默认为None
:param max_tokens: 最大token数可选默认为1024
:param temperature: 温度可选默认为0.7
:param response_format: 响应格式(可选,默认为 NotGiven
:param stream_response_handler: 流式响应处理函数(可选)
:param async_response_parser: 响应解析函数(可选)
:param interrupt_flag: 中断信号量可选默认为None
:return: (响应文本, 推理文本, 工具调用, 其他数据)
async def get_response(self, request: ResponseRequest) -> APIResponse:
"""获取对话响应。
Args:
request: 统一响应请求对象。
Returns:
APIResponse: 统一响应对象。
"""
raise NotImplementedError("'get_response' method should be overridden in subclasses")
@abstractmethod
async def get_embedding(
self,
model_info: ModelInfo,
embedding_input: str,
extra_params: dict[str, Any] | None = None,
) -> APIResponse:
"""
获取文本嵌入
:param model_info: 模型信息
:param embedding_input: 嵌入输入文本
:return: 嵌入响应
async def get_embedding(self, request: EmbeddingRequest) -> APIResponse:
"""获取文本嵌入。
Args:
request: 统一嵌入请求对象。
Returns:
APIResponse: 嵌入响应。
"""
raise NotImplementedError("'get_embedding' method should be overridden in subclasses")
@abstractmethod
async def get_audio_transcriptions(
self,
model_info: ModelInfo,
audio_base64: str,
max_tokens: Optional[int] = None,
extra_params: dict[str, Any] | None = None,
) -> APIResponse:
"""
获取音频转录
:param model_info: 模型信息
:param audio_base64: base64编码的音频数据
:extra_params: 附加的请求参数
:return: 音频转录响应
async def get_audio_transcriptions(self, request: AudioTranscriptionRequest) -> APIResponse:
"""获取音频转录。
Args:
request: 统一音频转录请求对象。
Returns:
APIResponse: 音频转录响应。
"""
raise NotImplementedError("'get_audio_transcriptions' method should be overridden in subclasses")
@abstractmethod
def get_support_image_formats(self) -> list[str]:
"""
获取支持的图片格式
:return: 支持的图片格式列表
def get_support_image_formats(self) -> List[str]:
"""获取支持的图片格式。
Returns:
List[str]: 支持的图片格式列表。
"""
raise NotImplementedError("'get_support_image_formats' method should be overridden in subclasses")
class ClientRegistry:
"""客户端注册表。"""
def __init__(self) -> None:
self.client_registry: dict[str, type[BaseClient]] = {}
"""初始化注册表并绑定配置重载回调。"""
self.client_registry: Dict[str, Type[BaseClient]] = {}
"""APIProvider.type -> BaseClient的映射表"""
self.client_instance_cache: dict[str, BaseClient] = {}
self.client_instance_cache: Dict[str, BaseClient] = {}
"""APIProvider.name -> BaseClient的映射表"""
config_manager.register_reload_callback(self.clear_client_instance_cache)
def register_client_class(self, client_type: str):
"""
注册API客户端类
def register_client_class(self, client_type: str) -> Callable[[Type[BaseClient]], Type[BaseClient]]:
"""注册 API 客户端类。
Args:
client_class: API客户端类
client_type: 客户端类型标识。
Returns:
Callable[[Type[BaseClient]], Type[BaseClient]]: 装饰器函数。
"""
def decorator(cls: type[BaseClient]) -> type[BaseClient]:
def decorator(cls: Type[BaseClient]) -> Type[BaseClient]:
if not issubclass(cls, BaseClient):
raise TypeError(f"{cls.__name__} is not a subclass of BaseClient")
self.client_registry[client_type] = cls
@@ -165,14 +227,15 @@ class ClientRegistry:
return decorator
def get_client_class_instance(self, api_provider: APIProvider, force_new=False) -> BaseClient:
"""
获取注册的API客户端实例
def get_client_class_instance(self, api_provider: APIProvider, force_new: bool = False) -> BaseClient:
"""获取注册的 API 客户端实例。
Args:
api_provider: APIProvider实例
force_new: 是否强制创建新实例(用于解决事件循环问题)
api_provider: APIProvider 实例
force_new: 是否强制创建新实例
Returns:
BaseClient: 注册的API客户端实例
BaseClient: 注册的 API 客户端实例
"""
from . import ensure_client_type_loaded
@@ -194,6 +257,7 @@ class ClientRegistry:
return self.client_instance_cache[api_provider.name]
def clear_client_instance_cache(self) -> None:
"""清空客户端实例缓存。"""
self.client_instance_cache.clear()
logger.info("检测到配置重载已清空LLM客户端实例缓存")

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,140 @@
from dataclasses import dataclass, field
from typing import Any, Mapping
from src.config.model_configs import APIProvider, OpenAICompatibleAuthType
@dataclass(slots=True)
class OpenAICompatibleClientConfig:
    """Base configuration for an OpenAI-compatible client."""

    # API key handed to the client. build_openai_compatible_client_config sets
    # it to "" whenever the credential travels via headers / query, or auth is off.
    api_key: str
    # Base URL of the API endpoint.
    base_url: str
    # Headers sent by default; may carry the auth header.
    default_headers: dict[str, str] = field(default_factory=dict)
    # Query parameters sent by default; may carry the auth parameter.
    default_query: dict[str, object] = field(default_factory=dict)
@dataclass(slots=True)
class OpenAICompatibleRequestOverrides:
    """Additional configuration that applies to a single request."""

    # Extra headers for this request.
    extra_headers: dict[str, str] = field(default_factory=dict)
    # Extra query parameters for this request.
    extra_query: dict[str, object] = field(default_factory=dict)
    # Extra request-body fields for this request.
    extra_body: dict[str, Any] = field(default_factory=dict)
def normalize_openai_base_url(base_url: str) -> str:
    """Normalize the base URL of an OpenAI-compatible endpoint.

    Surrounding whitespace is removed first; otherwise a stray trailing space
    or newline (for example from a config file) would hide the trailing slash
    from ``rstrip("/")`` and leave the URL un-normalized.

    Args:
        base_url: The raw base URL.

    Returns:
        str: The URL without surrounding whitespace and without trailing slashes.
    """
    return base_url.strip().rstrip("/")
def _build_auth_header_value(prefix: str, api_key: str) -> str:
"""构造鉴权请求头的值。
Args:
prefix: 请求头前缀。
api_key: 实际密钥。
Returns:
str: 拼接完成的请求头值。
"""
normalized_prefix = prefix.strip()
if not normalized_prefix:
return api_key
return f"{normalized_prefix} {api_key}"
def build_openai_compatible_client_config(api_provider: APIProvider) -> OpenAICompatibleClientConfig:
    """Build the client configuration for an OpenAI-compatible provider.

    The provider's ``auth_type`` decides where the credential goes:

    * ``BEARER`` with the ``Authorization`` header and ``Bearer`` prefix: the
      key is passed through as the client's ``api_key``.
    * ``BEARER`` with a different header name or prefix, or ``HEADER``: the key
      is written into the default headers and the client ``api_key`` is ``""``.
    * ``QUERY``: the key is written into the default query parameters and the
      client ``api_key`` is ``""``.
    * ``NONE``: no credential is sent; the client ``api_key`` is ``""``.

    Any other ``auth_type`` leaves the key as the client's ``api_key``.

    Args:
        api_provider: The API provider configuration.

    Returns:
        OpenAICompatibleClientConfig: Configuration intended for initializing
        the SDK client.
    """
    headers = dict(api_provider.default_headers)
    query: dict[str, object] = dict(api_provider.default_query)
    sdk_api_key = api_provider.api_key
    auth_type = api_provider.auth_type

    if auth_type == OpenAICompatibleAuthType.BEARER:
        is_standard_bearer = (
            api_provider.auth_header_name == "Authorization"
            and api_provider.auth_header_prefix.strip() == "Bearer"
        )
        # A non-standard bearer setup has to be spelled out as an explicit header.
        send_via_header = not is_standard_bearer
    else:
        send_via_header = auth_type == OpenAICompatibleAuthType.HEADER

    if send_via_header:
        sdk_api_key = ""
        headers[api_provider.auth_header_name] = _build_auth_header_value(
            prefix=api_provider.auth_header_prefix,
            api_key=api_provider.api_key,
        )
    elif auth_type == OpenAICompatibleAuthType.QUERY:
        sdk_api_key = ""
        query[api_provider.auth_query_name] = api_provider.api_key
    elif auth_type == OpenAICompatibleAuthType.NONE:
        sdk_api_key = ""

    return OpenAICompatibleClientConfig(
        api_key=sdk_api_key,
        base_url=normalize_openai_base_url(api_provider.base_url),
        default_headers=headers,
        default_query=query,
    )
def _extract_mapping(value: Any) -> dict[str, Any]:
"""将任意映射值规范化为普通字典。
Args:
value: 原始输入值。
Returns:
dict[str, Any]: 规范化后的字典。非映射值时返回空字典。
"""
if isinstance(value, Mapping):
return {str(key): item for key, item in value.items()}
return {}
def split_openai_request_overrides(
    extra_params: Mapping[str, Any] | None,
    *,
    reserved_body_keys: set[str] | None = None,
) -> OpenAICompatibleRequestOverrides:
    """Split per-request extras into header, query and body overrides.

    The ``headers``, ``query`` and ``body`` entries of ``extra_params`` are taken
    as mappings (anything else is ignored). Every remaining top-level entry is
    added to the body overrides, overriding a same-named key from ``body``,
    unless its key is reserved. Reserved keys are only filtered from the
    top-level entries, not from the explicit ``body`` mapping.

    Args:
        extra_params: Model-level or request-level extra parameters.
        reserved_body_keys: Keys carried by native SDK parameters, which must
            therefore not be repeated in ``extra_body``.

    Returns:
        OpenAICompatibleRequestOverrides: The split overrides. Header values are
        converted to ``str``.
    """
    remaining = dict(extra_params) if extra_params else {}
    header_overrides = _extract_mapping(remaining.pop("headers", None))
    query_overrides = _extract_mapping(remaining.pop("query", None))
    body_overrides = _extract_mapping(remaining.pop("body", None))

    reserved = reserved_body_keys or set()
    # Whatever is left at the top level is treated as a body field.
    body_overrides.update(
        {name: payload for name, payload in remaining.items() if name not in reserved}
    )

    return OpenAICompatibleRequestOverrides(
        extra_headers={name: str(content) for name, content in header_overrides.items()},
        extra_query=query_overrides,
        extra_body=body_overrides,
    )

View File

@@ -1,133 +1,280 @@
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional
from typing import List, Tuple
from .tool_option import ToolCall
# 设计这系列类的目的是为未来可能的扩展做准备
class RoleType(str, Enum):
"""消息角色类型。"""
class RoleType(Enum):
System = "system"
User = "user"
Assistant = "assistant"
Tool = "tool"
SUPPORTED_IMAGE_FORMATS = ["jpg", "jpeg", "png", "webp", "gif"] # openai支持的图片格式
SUPPORTED_IMAGE_FORMATS = ["jpg", "jpeg", "png", "webp", "gif"]
"""默认支持的图片格式列表。"""
@dataclass(slots=True)
class TextMessagePart:
"""文本消息片段。"""
text: str
def __post_init__(self) -> None:
"""执行文本片段的基础校验。
Raises:
ValueError: 当文本为空时抛出。
"""
if self.text == "":
raise ValueError("文本消息片段不能为空字符串")
@dataclass(slots=True)
class ImageMessagePart:
"""Base64 图片消息片段。"""
image_format: str
image_base64: str
def __post_init__(self) -> None:
"""执行图片片段的基础校验。
Raises:
ValueError: 当图片格式或 Base64 数据无效时抛出。
"""
if self.image_format.lower() not in SUPPORTED_IMAGE_FORMATS:
raise ValueError("不受支持的图片格式")
if not self.image_base64:
raise ValueError("图片的 base64 编码不能为空")
@property
def normalized_image_format(self) -> str:
"""获取规范化后的图片格式。
Returns:
str: 规范化后的图片格式。`jpg` 会被统一为 `jpeg`。
"""
image_format = self.image_format.lower()
if image_format in {"jpg", "jpeg"}:
return "jpeg"
return image_format
MessagePart = TextMessagePart | ImageMessagePart
@dataclass(slots=True)
class Message:
def __init__(
self,
role: RoleType,
content: str | list[tuple[str, str] | str],
tool_call_id: str | None = None,
tool_calls: Optional[List[ToolCall]] = None,
):
"""统一消息模型。"""
role: RoleType
parts: List[MessagePart] = field(default_factory=list)
tool_call_id: str | None = None
tool_calls: List[ToolCall] | None = None
def __post_init__(self) -> None:
"""执行消息对象的基础校验。
Raises:
ValueError: 当消息内容或工具调用信息不完整时抛出。
"""
初始化消息对象
不应直接修改Message类而应使用MessageBuilder类来构建对象
if not self.parts and not (self.role == RoleType.Assistant and self.tool_calls):
raise ValueError("消息内容不能为空")
if self.role == RoleType.Tool and not self.tool_call_id:
raise ValueError("Tool 角色的工具调用 ID 不能为空")
@property
def content(self) -> str | List[Tuple[str, str] | str]:
"""获取兼容旧逻辑的内容视图。
Returns:
str | List[Tuple[str, str] | str]: 当仅包含一个文本片段时返回字符串,
否则返回混合列表,其中图片片段表示为 `(format, base64)` 元组。
"""
self.role: RoleType = role
self.content: str | list[tuple[str, str] | str] = content
self.tool_call_id: str | None = tool_call_id
self.tool_calls: Optional[List[ToolCall]] = tool_calls
if len(self.parts) == 1 and isinstance(self.parts[0], TextMessagePart):
return self.parts[0].text
content_items: List[Tuple[str, str] | str] = []
for part in self.parts:
if isinstance(part, TextMessagePart):
content_items.append(part.text)
else:
content_items.append((part.image_format, part.image_base64))
return content_items
def get_text_content(self) -> str:
"""提取消息中的所有文本片段。
Returns:
str: 以原始顺序拼接后的文本内容。
"""
return "".join(part.text for part in self.parts if isinstance(part, TextMessagePart))
def __str__(self) -> str:
"""生成便于调试的字符串表示。
Returns:
str: 当前消息对象的可读摘要。
"""
return (
f"Role: {self.role}, Content: {self.content}, "
f"Role: {self.role}, Parts: {self.parts}, "
f"Tool Call ID: {self.tool_call_id}, Tool Calls: {self.tool_calls}"
)
class MessageBuilder:
def __init__(self):
"""消息构建器。"""
def __init__(self) -> None:
"""初始化构建器。"""
self.__role: RoleType = RoleType.User
self.__content: list[tuple[str, str] | str] = []
self.__parts: List[MessagePart] = []
self.__tool_call_id: str | None = None
self.__tool_calls: Optional[List[ToolCall]] = None
self.__tool_calls: List[ToolCall] | None = None
def set_role(self, role: RoleType = RoleType.User) -> "MessageBuilder":
"""
设置角色默认为User
:param role: 角色
:return: MessageBuilder对象
"""设置消息角色。
Args:
role: 目标角色,默认为 `user`。
Returns:
MessageBuilder: 当前构建器实例。
"""
self.__role = role
return self
def add_text_part(self, text: str) -> "MessageBuilder":
"""追加文本片段。
Args:
text: 文本内容。
Returns:
MessageBuilder: 当前构建器实例。
"""
self.__parts.append(TextMessagePart(text=text))
return self
def add_text_content(self, text: str) -> "MessageBuilder":
"""追加文本片段。
Args:
text: 文本内容。
Returns:
MessageBuilder: 当前构建器实例。
"""
添加文本内容
:param text: 文本内容
:return: MessageBuilder对象
return self.add_text_part(text)
def add_image_base64_part(
self,
image_format: str,
image_base64: str,
support_formats: List[str] = SUPPORTED_IMAGE_FORMATS,
) -> "MessageBuilder":
"""追加 Base64 图片片段。
Args:
image_format: 图片格式。
image_base64: 图片的 Base64 编码。
support_formats: 允许的图片格式列表。
Returns:
MessageBuilder: 当前构建器实例。
Raises:
ValueError: 当图片格式不被支持时抛出。
"""
self.__content.append(text)
if image_format.lower() not in support_formats:
raise ValueError("不受支持的图片格式")
self.__parts.append(ImageMessagePart(image_format=image_format, image_base64=image_base64))
return self
def add_image_content(
self,
image_format: str,
image_base64: str,
support_formats: list[str] = SUPPORTED_IMAGE_FORMATS, # 默认支持格式
support_formats: List[str] = SUPPORTED_IMAGE_FORMATS,
) -> "MessageBuilder":
"""
添加图片内容
:param image_format: 图片格式
:param image_base64: 图片的base64编码
:return: MessageBuilder对象
"""
if image_format.lower() not in support_formats:
raise ValueError("不受支持的图片格式")
if not image_base64:
raise ValueError("图片的base64编码不能为空")
self.__content.append((image_format, image_base64))
return self
"""追加 Base64 图片片段。
def add_tool_call(self, tool_call_id: str) -> "MessageBuilder":
Args:
image_format: 图片格式。
image_base64: 图片的 Base64 编码。
support_formats: 允许的图片格式列表。
Returns:
MessageBuilder: 当前构建器实例。
"""
添加工具调用指令调用时请确保已设置为Tool角色
:param tool_call_id: 工具调用指令的id
:return: MessageBuilder对象
return self.add_image_base64_part(
image_format=image_format,
image_base64=image_base64,
support_formats=support_formats,
)
def set_tool_call_id(self, tool_call_id: str) -> "MessageBuilder":
"""设置工具结果消息引用的工具调用 ID。
Args:
tool_call_id: 工具调用 ID。
Returns:
MessageBuilder: 当前构建器实例。
Raises:
ValueError: 当当前角色不是 `tool` 或 ID 为空时抛出。
"""
if self.__role != RoleType.Tool:
raise ValueError("仅当角色为Tool时才能添加工具调用ID")
raise ValueError("仅当角色为 Tool 时才能设置工具调用 ID")
if not tool_call_id:
raise ValueError("工具调用ID不能为空")
raise ValueError("工具调用 ID 不能为空")
self.__tool_call_id = tool_call_id
return self
def set_tool_calls(self, tool_calls: List[ToolCall]) -> "MessageBuilder":
def add_tool_call(self, tool_call_id: str) -> "MessageBuilder":
"""设置工具结果消息引用的工具调用 ID。
Args:
tool_call_id: 工具调用 ID。
Returns:
MessageBuilder: 当前构建器实例。
"""
设置助手消息的工具调用列表
:param tool_calls: 工具调用列表
:return: MessageBuilder对象
return self.set_tool_call_id(tool_call_id)
def set_tool_calls(self, tool_calls: List[ToolCall]) -> "MessageBuilder":
"""设置助手消息中的工具调用列表。
Args:
tool_calls: 工具调用列表。
Returns:
MessageBuilder: 当前构建器实例。
Raises:
ValueError: 当当前角色不是 `assistant` 或列表为空时抛出。
"""
if self.__role != RoleType.Assistant:
raise ValueError("仅当角色为Assistant时才能设置工具调用列表")
raise ValueError("仅当角色为 Assistant 时才能设置工具调用列表")
if not tool_calls:
raise ValueError("工具调用列表不能为空")
self.__tool_calls = tool_calls
self.__tool_calls = list(tool_calls)
return self
def build(self) -> Message:
"""
构建消息对象
:return: Message对象
"""
if len(self.__content) == 0 and not (self.__role == RoleType.Assistant and self.__tool_calls):
raise ValueError("内容不能为空")
if self.__role == RoleType.Tool and self.__tool_call_id is None:
raise ValueError("Tool角色的工具调用ID不能为空")
"""构建消息对象。
Returns:
Message: 构建完成的消息对象。
"""
return Message(
role=self.__role,
content=(
self.__content[0]
if (len(self.__content) == 1 and isinstance(self.__content[0], str))
else self.__content
),
parts=list(self.__parts),
tool_call_id=self.__tool_call_id,
tool_calls=self.__tool_calls,
tool_calls=list(self.__tool_calls) if self.__tool_calls else None,
)

View File

@@ -1,51 +1,40 @@
from copy import deepcopy
from enum import Enum
from typing import Optional, Any
from typing import Any, Dict, List, Mapping, Optional, Type, cast
from pydantic import BaseModel
from typing_extensions import TypedDict, Required
from typing_extensions import Required, TypedDict
class RespFormatType(Enum):
TEXT = "text" # 文本
JSON_OBJ = "json_object" # JSON
JSON_SCHEMA = "json_schema" # JSON Schema
"""响应格式类型。"""
TEXT = "text"
JSON_OBJ = "json_object"
JSON_SCHEMA = "json_schema"
class JsonSchema(TypedDict, total=False):
"""内部使用的 JSON Schema 包装结构。"""
name: Required[str]
"""
The name of the response format.
Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length
of 64.
"""
description: Optional[str]
"""
A description of what the response format is for, used by the model to determine
how to respond in the format.
"""
schema: dict[str, object]
"""
The schema for the response format, described as a JSON Schema object. Learn how
to build JSON schemas [here](https://json-schema.org/).
"""
schema: Dict[str, Any]
strict: Optional[bool]
"""
Whether to enable strict schema adherence when generating the output. If set to
true, the model will always follow the exact schema defined in the `schema`
field. Only a subset of JSON Schema is supported when `strict` is `true`. To
learn more, read the
[Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs).
"""
def _json_schema_type_check(instance) -> str | None:
def _json_schema_type_check(instance: Mapping[str, Any]) -> str | None:
"""检查 JSON Schema 包装结构是否合法。
Args:
instance: 待检查的 JSON Schema 包装字典。
Returns:
str | None: 不合法时返回错误信息,合法时返回 `None`。
"""
if "name" not in instance:
return "schema必须包含'name'字段"
elif not isinstance(instance["name"], str) or instance["name"].strip() == "":
if not isinstance(instance["name"], str) or instance["name"].strip() == "":
return "schema的'name'字段必须是非空字符串"
if "description" in instance and (
not isinstance(instance["description"], str) or instance["description"].strip() == ""
@@ -53,164 +42,198 @@ def _json_schema_type_check(instance) -> str | None:
return "schema的'description'字段只能填入非空字符串"
if "schema" not in instance:
return "schema必须包含'schema'字段"
elif not isinstance(instance["schema"], dict):
if not isinstance(instance["schema"], dict):
return "schema的'schema'字段必须是字典详见https://json-schema.org/"
if "strict" in instance and not isinstance(instance["strict"], bool):
return "schema的'strict'字段只能填入布尔值"
return None
def _remove_title(schema: dict[str, Any] | list[Any]) -> dict[str, Any] | list[Any]:
"""
递归移除JSON Schema中的title字段
def _remove_title(schema: Dict[str, Any] | List[Any]) -> Dict[str, Any] | List[Any]:
"""递归移除 JSON Schema 中的 `title` 字段。
Args:
schema: 待处理的 Schema 树。
Returns:
Dict[str, Any] | List[Any]: 处理后的 Schema 树。
"""
if isinstance(schema, list):
# 如果当前Schema是列表则对所有dict/list子元素递归调用
for idx, item in enumerate(schema):
for index, item in enumerate(schema):
if isinstance(item, (dict, list)):
schema[idx] = _remove_title(item)
elif isinstance(schema, dict):
# 是字典移除title字段并对所有dict/list子元素递归调用
if "title" in schema:
del schema["title"]
for key, value in schema.items():
if isinstance(value, (dict, list)):
schema[key] = _remove_title(value)
schema[index] = _remove_title(item)
return schema
if "title" in schema:
del schema["title"]
for key, value in schema.items():
if isinstance(value, (dict, list)):
schema[key] = _remove_title(value)
return schema
def _link_definitions(schema: dict[str, Any]) -> dict[str, Any]:
"""
链接JSON Schema中的definitions字段
def _link_definitions(schema: Dict[str, Any]) -> Dict[str, Any]:
"""展开 Schema 中的本地 `$defs`/`$ref` 引用。
Args:
schema: 待处理的根 Schema。
Returns:
Dict[str, Any]: 展开后的 Schema。
"""
def link_definitions_recursive(
path: str, sub_schema: list[Any] | dict[str, Any], defs: dict[str, Any]
) -> dict[str, Any]:
"""
递归链接JSON Schema中的definitions字段
:param path: 当前路径
:param sub_schema: 子Schema
:param defs: Schema定义集
:return:
path: str,
sub_schema: Dict[str, Any] | List[Any],
definitions: Dict[str, Any],
) -> Dict[str, Any] | List[Any]:
"""递归展开局部定义。
Args:
path: 当前递归路径。
sub_schema: 当前子 Schema。
definitions: 已收集的定义字典。
Returns:
Dict[str, Any] | List[Any]: 展开后的子 Schema。
"""
if isinstance(sub_schema, list):
# 如果当前Schema是列表则遍历每个元素
for i in range(len(sub_schema)):
if isinstance(sub_schema[i], dict):
sub_schema[i] = link_definitions_recursive(f"{path}/{str(i)}", sub_schema[i], defs)
else:
# 否则为字典
if "$defs" in sub_schema:
# 如果当前Schema有$def字段则将其添加到defs中
key_prefix = f"{path}/$defs/"
for key, value in sub_schema["$defs"].items():
def_key = key_prefix + key
if def_key not in defs:
defs[def_key] = value
del sub_schema["$defs"]
if "$ref" in sub_schema:
# 如果当前Schema有$ref字段则将其替换为defs中的定义
def_key = sub_schema["$ref"]
if def_key in defs:
sub_schema = defs[def_key]
else:
raise ValueError(f"Schema中引用的定义'{def_key}'不存在")
# 遍历键值对
for key, value in sub_schema.items():
if isinstance(value, (dict, list)):
# 如果当前值是字典或列表,则递归调用
sub_schema[key] = link_definitions_recursive(f"{path}/{key}", value, defs)
for index, item in enumerate(sub_schema):
if isinstance(item, (dict, list)):
sub_schema[index] = link_definitions_recursive(f"{path}/{index}", item, definitions)
return sub_schema
if "$defs" in sub_schema:
key_prefix = f"{path}/$defs/"
defs_payload = cast(Dict[str, Any], sub_schema["$defs"])
for key, value in defs_payload.items():
definition_key = key_prefix + key
if definition_key not in definitions:
definitions[definition_key] = value
del sub_schema["$defs"]
if "$ref" in sub_schema:
definition_key = cast(str, sub_schema["$ref"])
if definition_key in definitions:
return definitions[definition_key]
raise ValueError(f"Schema中引用的定义'{definition_key}'不存在")
for key, value in sub_schema.items():
if isinstance(value, (dict, list)):
sub_schema[key] = link_definitions_recursive(f"{path}/{key}", value, definitions)
return sub_schema
return link_definitions_recursive("#", schema, {})
return cast(Dict[str, Any], link_definitions_recursive("#", schema, {}))
def _remove_defs(schema: dict[str, Any]) -> dict[str, Any]:
"""
递归移除JSON Schema中的$defs字段
def _remove_defs(schema: Dict[str, Any] | List[Any]) -> Dict[str, Any] | List[Any]:
"""递归移除 JSON Schema 中的 `$defs` 字段。
Args:
schema: 待处理的 Schema 树。
Returns:
Dict[str, Any] | List[Any]: 处理后的 Schema 树。
"""
if isinstance(schema, list):
# 如果当前Schema是列表则对所有dict/list子元素递归调用
for idx, item in enumerate(schema):
for index, item in enumerate(schema):
if isinstance(item, (dict, list)):
schema[idx] = _remove_title(item)
elif isinstance(schema, dict):
# 是字典移除title字段并对所有dict/list子元素递归调用
if "$defs" in schema:
del schema["$defs"]
for key, value in schema.items():
if isinstance(value, (dict, list)):
schema[key] = _remove_title(value)
schema[index] = _remove_defs(item)
return schema
if "$defs" in schema:
del schema["$defs"]
for key, value in schema.items():
if isinstance(value, (dict, list)):
schema[key] = _remove_defs(value)
return schema
class RespFormat:
"""
响应格式
"""
"""统一响应格式定义。"""
@staticmethod
def _generate_schema_from_model(schema_model: Type[BaseModel]) -> JsonSchema:
    """Build the internal JSON Schema wrapper from a Pydantic model class.

    The model's schema is deep-copied, then passed through ``_remove_title``,
    ``_link_definitions`` and ``_remove_defs`` in that order.

    Args:
        schema_model: Pydantic model class.

    Returns:
        JsonSchema: Wrapper holding ``name``, ``schema`` and ``strict`` (always
        ``False``), plus ``description`` when the model class has a docstring.
    """
    # Copy first so the clean-up passes never touch the object the model returned.
    schema_tree = deepcopy(schema_model.model_json_schema())
    untitled_tree = cast(Dict[str, Any], _remove_title(schema_tree))
    resolved_tree = cast(Dict[str, Any], _remove_defs(_link_definitions(untitled_tree)))

    wrapper: JsonSchema = {
        "name": schema_model.__name__,
        "schema": resolved_tree,
        "strict": False,
    }
    model_doc = schema_model.__doc__
    if model_doc:
        wrapper["description"] = model_doc
    return wrapper
def __init__(
self,
format_type: RespFormatType = RespFormatType.TEXT,
schema: type | JsonSchema | None = None,
):
"""
响应格式
:param format_type: 响应格式类型(默认为文本)
:param schema: 模板类或JsonSchema仅当format_type为JSON Schema时有效
schema: Type[BaseModel] | JsonSchema | None = None,
) -> None:
"""初始化响应格式对象。
Args:
format_type: 响应格式类型。
schema: 模型类或 JSON Schema 包装结构,仅 `JSON_SCHEMA` 模式使用。
"""
self.format_type: RespFormatType = format_type
self.schema_source: Type[BaseModel] | JsonSchema | None = schema
self.schema: JsonSchema | None = None
if format_type == RespFormatType.JSON_SCHEMA:
if schema is None:
raise ValueError("当format_type为'JSON_SCHEMA'schema不能为空")
if isinstance(schema, dict):
if check_msg := _json_schema_type_check(schema):
raise ValueError(f"schema格式不正确{check_msg}")
if format_type != RespFormatType.JSON_SCHEMA:
return
if schema is None:
raise ValueError("当format_type为'JSON_SCHEMA'schema不能为空")
if isinstance(schema, dict):
if check_msg := _json_schema_type_check(schema):
raise ValueError(f"schema格式不正确{check_msg}")
self.schema = cast(JsonSchema, deepcopy(schema))
return
if isinstance(schema, type) and issubclass(schema, BaseModel):
try:
self.schema = self._generate_schema_from_model(schema)
except Exception as exc:
raise ValueError(
f"自动生成JSON Schema时发生异常请检查模型类{schema.__name__}的定义,详细信息:\n"
f"{schema.__name__}:\n"
) from exc
return
raise ValueError("schema必须是BaseModel的子类或JsonSchema")
self.schema = schema
elif issubclass(schema, BaseModel):
try:
json_schema = self._generate_schema_from_model(schema)
def get_schema_object(self) -> Dict[str, Any] | None:
"""获取内部包装中的对象级 JSON Schema。
self.schema = json_schema
except Exception as e:
raise ValueError(
f"自动生成JSON Schema时发生异常请检查模型类{schema.__name__}的定义,详细信息:\n"
f"{schema.__name__}:\n"
) from e
else:
raise ValueError("schema必须是BaseModel的子类或JsonSchema")
else:
self.schema = None
def to_dict(self):
Returns:
Dict[str, Any] | None: 对象级 JSON Schema不存在时返回 `None`。
"""
将响应格式转换为字典
:return: 字典
if self.schema is None:
return None
schema_payload = self.schema.get("schema")
if isinstance(schema_payload, dict):
return cast(Dict[str, Any], deepcopy(schema_payload))
return None
def to_dict(self) -> Dict[str, Any]:
    """Serialise the response format to a plain dictionary.

    Returns:
        Dict[str, Any]: Always contains ``format_type``; contains ``schema`` only
        when a non-empty schema wrapper is set.
    """
    payload: Dict[str, Any] = {"format_type": self.format_type.value}
    if self.schema:
        # The stored wrapper is embedded as-is (no copy is made).
        payload["schema"] = self.schema
    return payload

View File

@@ -1,83 +1,368 @@
from copy import deepcopy
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Tuple, TypeAlias, cast
class ToolParamType(str, Enum):
    """JSON Schema type names usable for tool parameters.

    Members are also ``str`` instances, so they compare equal to their raw
    value (for example ``ToolParamType.STRING == "string"``).
    """

    STRING = "string"
    INTEGER = "integer"
    NUMBER = "number"
    # Same value as NUMBER, so FLOAT is an alias of that member.
    FLOAT = "number"
    BOOLEAN = "boolean"
    ARRAY = "array"
    OBJECT = "object"
LegacyToolParameterTuple = Tuple[str, ToolParamType, str, bool, List[str] | None]
"""旧版工具参数元组格式。"""
def normalize_tool_param_type(raw_value: ToolParamType | str | None) -> ToolParamType:
    """Map an arbitrary type spelling onto the internal ``ToolParamType``.

    Args:
        raw_value: An enum member, a type name such as ``"int"`` or ``"bool"``,
            or ``None``.

    Returns:
        ToolParamType: The matching member. Enum members are returned
        unchanged; unrecognised or empty input falls back to ``STRING``.
    """
    if isinstance(raw_value, ToolParamType):
        return raw_value

    # Accepted spellings, matched after trimming and lower-casing.
    aliases: Dict[str, ToolParamType] = {
        "integer": ToolParamType.INTEGER,
        "int": ToolParamType.INTEGER,
        "number": ToolParamType.NUMBER,
        "float": ToolParamType.NUMBER,
        "boolean": ToolParamType.BOOLEAN,
        "bool": ToolParamType.BOOLEAN,
        "array": ToolParamType.ARRAY,
        "object": ToolParamType.OBJECT,
    }
    token = str(raw_value or "").strip().lower()
    return aliases.get(token, ToolParamType.STRING)
def _is_object_schema(schema: Dict[str, Any]) -> bool:
"""判断输入字典是否已经是对象级 JSON Schema。
Args:
schema: 待判断的字典。
Returns:
bool: 为对象级 JSON Schema 时返回 `True`。
"""
STRING = "string" # 字符串
INTEGER = "integer" # 整型
FLOAT = "float" # 浮点型
BOOLEAN = "bool" # 布尔型
return schema.get("type") == "object" or "properties" in schema or "required" in schema
def _build_parameters_schema_from_property_map(property_map: Dict[str, Any]) -> Dict[str, Any]:
"""将属性映射转换为对象级 JSON Schema。
Args:
property_map: 仅包含属性定义的映射。
Returns:
Dict[str, Any]: 对象级 JSON Schema。
"""
required_names: List[str] = []
normalized_properties: Dict[str, Any] = {}
for property_name, property_schema in property_map.items():
if not isinstance(property_schema, dict):
continue
property_schema_copy = deepcopy(property_schema)
is_required = bool(property_schema_copy.pop("required", False))
if is_required:
required_names.append(str(property_name))
normalized_properties[str(property_name)] = property_schema_copy
parameters_schema: Dict[str, Any] = {
"type": "object",
"properties": normalized_properties,
}
if required_names:
parameters_schema["required"] = required_names
return parameters_schema
@dataclass(slots=True)
class ToolParam:
    """Definition of a single tool parameter.

    An instance converts to a property-level JSON Schema fragment through
    ``to_json_schema``.
    """

    # Parameter name; must be non-empty.
    name: str
    # JSON Schema type of the parameter.
    param_type: ToolParamType
    # Human-readable description copied into the schema.
    description: str
    # Whether the parameter is mandatory for the tool.
    required: bool
    # Allowed values; emitted as ``enum`` when non-empty.
    enum_values: List[Any] | None = None
    # Element schema; mandatory when ``param_type`` is ARRAY.
    items_schema: Dict[str, Any] | None = None
    # Nested property schemas, used for OBJECT parameters.
    properties: Dict[str, Dict[str, Any]] | None = None
    # Names of mandatory nested properties, used for OBJECT parameters.
    required_properties: List[str] = field(default_factory=list)
    # Emitted as ``additionalProperties`` for OBJECT parameters when not None.
    additional_properties: bool | Dict[str, Any] | None = None
    # Default value; ``None`` means "no default" and is not emitted.
    default: Any = None

    def __post_init__(self) -> None:
        """Validate the definition right after construction.

        An OBJECT parameter created without ``properties`` gets an empty
        mapping.

        Raises:
            ValueError: If the name is empty, or an ARRAY parameter has no
                ``items_schema``.
        """
        if not self.name:
            raise ValueError("参数名称不能为空")
        if self.param_type == ToolParamType.ARRAY and self.items_schema is None:
            raise ValueError("数组参数必须提供 items_schema")
        if self.param_type == ToolParamType.OBJECT and self.properties is None:
            self.properties = {}

    @classmethod
    def from_legacy_tuple(cls, parameter: LegacyToolParameterTuple) -> "ToolParam":
        """Build a parameter from the legacy five-element tuple.

        The tuple layout is ``(name, type, description, required, enum_values)``.

        Args:
            parameter: Legacy parameter tuple.

        Returns:
            ToolParam: The equivalent parameter object.
        """
        return cls(
            name=parameter[0],
            # Normalise like ``from_dict`` does: enum members pass through
            # unchanged, raw type names are mapped onto a member so that
            # ``to_json_schema`` can always read ``.value``.
            param_type=normalize_tool_param_type(parameter[1]),
            description=parameter[2],
            required=parameter[3],
            # Copy so later edits to the caller's list do not leak in.
            enum_values=deepcopy(parameter[4]),
        )

    @classmethod
    def from_dict(
        cls,
        name: str,
        parameter_schema: Dict[str, Any],
        *,
        required: bool = False,
    ) -> "ToolParam":
        """Build a parameter from a property-level schema or structured dict.

        Both the internal key names (``param_type``, ``enum_values``,
        ``items_schema``, ``required_properties``, ``additional_properties``)
        and the JSON Schema names (``type``, ``enum``, ``items``, ``required``,
        ``additionalProperties``) are understood; the internal name is looked
        up first. All nested values are deep-copied.

        Args:
            name: Parameter name.
            parameter_schema: Schema or structured definition of the parameter.
            required: Whether the parameter is mandatory.

        Returns:
            ToolParam: The normalised parameter object.
        """
        raw_required_properties = parameter_schema.get("required_properties")
        # A list under ``required`` is the nested object's required array.
        if raw_required_properties is None and isinstance(parameter_schema.get("required"), list):
            raw_required_properties = parameter_schema.get("required")

        return cls(
            name=name,
            param_type=normalize_tool_param_type(parameter_schema.get("param_type") or parameter_schema.get("type")),
            description=str(parameter_schema.get("description", "") or ""),
            required=required,
            enum_values=deepcopy(parameter_schema.get("enum_values") or parameter_schema.get("enum")),
            items_schema=deepcopy(parameter_schema.get("items_schema") or parameter_schema.get("items")),
            properties=deepcopy(parameter_schema.get("properties")),
            required_properties=list(raw_required_properties or []),
            # Key presence decides here, so an explicit ``False`` is honoured.
            additional_properties=deepcopy(
                parameter_schema["additional_properties"]
                if "additional_properties" in parameter_schema
                else parameter_schema.get("additionalProperties")
            ),
            default=deepcopy(parameter_schema.get("default")),
        )

    def to_json_schema(self) -> Dict[str, Any]:
        """Render the parameter as a JSON Schema fragment.

        Returns:
            Dict[str, Any]: Schema with ``type`` and ``description``, plus the
            optional keywords that apply to this parameter.
        """
        schema: Dict[str, Any] = {
            "type": self.param_type.value,
            "description": self.description,
        }
        if self.enum_values:
            schema["enum"] = list(self.enum_values)
        if self.default is not None:
            schema["default"] = deepcopy(self.default)
        if self.param_type == ToolParamType.ARRAY and self.items_schema is not None:
            schema["items"] = deepcopy(self.items_schema)
        if self.param_type == ToolParamType.OBJECT:
            schema["properties"] = deepcopy(self.properties or {})
            if self.required_properties:
                schema["required"] = list(self.required_properties)
            if self.additional_properties is not None:
                schema["additionalProperties"] = deepcopy(self.additional_properties)
        return schema
@dataclass(slots=True)
class ToolOption:
"""
工具调用项
"""
"""工具定义。"""
def __init__(
self,
name: str,
description: str,
params: list[ToolParam] | None = None,
):
name: str
description: str
params: List[ToolParam] | None = None
parameters_schema_override: Dict[str, Any] | None = None
def __post_init__(self) -> None:
"""执行工具定义的基础校验。
Raises:
ValueError: 当工具名称、描述或参数 Schema 不合法时抛出。
"""
初始化工具调用项
不应直接修改ToolOption类而应使用ToolOptionBuilder类来构建对象
:param name: 工具名称
:param description: 工具描述
:param params: 工具参数列表
if not self.name:
raise ValueError("工具名称不能为空")
if not self.description:
raise ValueError("工具描述不能为空")
if self.parameters_schema_override is not None:
schema_type = self.parameters_schema_override.get("type")
if schema_type != "object":
raise ValueError("工具参数 Schema 必须是 object 类型")
@classmethod
def from_definition(cls, definition: Dict[str, Any]) -> "ToolOption":
"""从任意支持的工具定义字典构建内部工具对象。
支持以下输入形状:
- `{"name", "description", "parameters_schema"}`
- `{"name", "description", "parameters"}`
- OpenAI function tool`{"type": "function", "function": {...}}`
- 仅属性映射的对象参数定义:`{"query": {"type": "string"}}`
Args:
definition: 原始工具定义字典。
Returns:
ToolOption: 规范化后的工具定义对象。
Raises:
ValueError: 当工具定义缺少必要字段时抛出。
"""
self.name: str = name
self.description: str = description
self.params: list[ToolParam] | None = params
if definition.get("type") == "function" and isinstance(definition.get("function"), dict):
function_definition = cast(Dict[str, Any], definition["function"])
return cls.from_definition(
{
"name": function_definition.get("name", ""),
"description": function_definition.get("description", ""),
"parameters_schema": function_definition.get("parameters"),
}
)
name = str(definition.get("name", "") or "").strip()
description = str(definition.get("description", "") or "").strip()
if not name:
raise ValueError("工具定义缺少 name")
if not description:
description = f"工具 {name}"
parameters_schema = definition.get("parameters_schema")
if isinstance(parameters_schema, dict):
normalized_schema = deepcopy(parameters_schema)
if not _is_object_schema(normalized_schema):
normalized_schema = _build_parameters_schema_from_property_map(normalized_schema)
return cls(
name=name,
description=description,
params=None,
parameters_schema_override=normalized_schema,
)
raw_parameters = definition.get("parameters")
if isinstance(raw_parameters, dict):
normalized_schema = deepcopy(raw_parameters)
if not _is_object_schema(normalized_schema):
normalized_schema = _build_parameters_schema_from_property_map(normalized_schema)
return cls(
name=name,
description=description,
params=None,
parameters_schema_override=normalized_schema,
)
if isinstance(raw_parameters, list):
params: List[ToolParam] = []
for raw_parameter in raw_parameters:
if isinstance(raw_parameter, tuple) and len(raw_parameter) == 5:
params.append(ToolParam.from_legacy_tuple(raw_parameter))
continue
if isinstance(raw_parameter, dict):
parameter_name = str(raw_parameter.get("name", "") or "").strip()
if not parameter_name:
continue
params.append(
ToolParam.from_dict(
parameter_name,
raw_parameter,
required=bool(raw_parameter.get("required", False)),
)
)
return cls(
name=name,
description=description,
params=params or None,
parameters_schema_override=None,
)
return cls(name=name, description=description, params=None, parameters_schema_override=None)
@property
def parameters_schema(self) -> Dict[str, Any] | None:
"""获取工具参数的对象级 JSON Schema。
Returns:
Dict[str, Any] | None: 工具参数 Schema。无参数工具时返回 `None`。
"""
if self.parameters_schema_override is not None:
return deepcopy(self.parameters_schema_override)
if not self.params:
return None
return {
"type": "object",
"properties": {param.name: param.to_json_schema() for param in self.params},
"required": [param.name for param in self.params if param.required],
}
def to_openai_function_schema(self) -> Dict[str, Any]:
"""转换为 OpenAI function calling 结构。
Returns:
Dict[str, Any]: OpenAI 兼容的工具定义。
"""
function_schema: Dict[str, Any] = {
"name": self.name,
"description": self.description,
}
if self.parameters_schema is not None:
function_schema["parameters"] = self.parameters_schema
return {
"type": "function",
"function": function_schema,
}
class ToolOptionBuilder:
"""
工具调用项构建器
"""
"""工具定义构建器。"""
def __init__(self):
def __init__(self) -> None:
"""初始化构建器。"""
self.__name: str = ""
self.__description: str = ""
self.__params: list[ToolParam] = []
self.__params: List[ToolParam] = []
self.__parameters_schema_override: Dict[str, Any] | None = None
def set_name(self, name: str) -> "ToolOptionBuilder":
"""
设置工具名称
:param name: 工具名称
:return: ToolBuilder实例
"""设置工具名称。
Args:
name: 工具名称。
Returns:
ToolOptionBuilder: 当前构建器实例。
Raises:
ValueError: 当名称为空时抛出。
"""
if not name:
raise ValueError("工具名称不能为空")
@@ -85,35 +370,76 @@ class ToolOptionBuilder:
return self
def set_description(self, description: str) -> "ToolOptionBuilder":
"""
设置工具描述
:param description: 工具描述
:return: ToolBuilder实例
"""设置工具描述。
Args:
description: 工具描述。
Returns:
ToolOptionBuilder: 当前构建器实例。
Raises:
ValueError: 当描述为空时抛出。
"""
if not description:
raise ValueError("工具描述不能为空")
self.__description = description
return self
def set_parameters_schema(self, schema: Dict[str, Any]) -> "ToolOptionBuilder":
"""直接设置完整的参数对象 Schema。
Args:
schema: 完整的对象级 JSON Schema。
Returns:
ToolOptionBuilder: 当前构建器实例。
Raises:
ValueError: 当 schema 不是 object 类型时抛出。
"""
if schema.get("type") != "object":
raise ValueError("工具参数 Schema 必须是 object 类型")
self.__parameters_schema_override = deepcopy(schema)
self.__params.clear()
return self
def add_param(
self,
name: str,
param_type: ToolParamType,
description: str,
required: bool = False,
enum_values: list[str] | None = None,
enum_values: List[Any] | None = None,
*,
items_schema: Dict[str, Any] | None = None,
properties: Dict[str, Dict[str, Any]] | None = None,
required_properties: List[str] | None = None,
additional_properties: bool | Dict[str, Any] | None = None,
default: Any = None,
) -> "ToolOptionBuilder":
"""
添加工具参数
:param name: 参数名称
:param param_type: 参数类型
:param description: 参数描述
:param required: 是否必填默认为False
:return: ToolBuilder实例
"""
if not name or not description:
raise ValueError("参数名称/描述不能为空")
"""添加一个参数定义。
Args:
name: 参数名称。
param_type: 参数类型。
description: 参数描述。
required: 参数是否必填。
enum_values: 可选的枚举值列表。
items_schema: 数组参数的元素 Schema。
properties: 对象参数的属性定义。
required_properties: 对象参数内部的必填字段。
additional_properties: 对象参数是否允许额外字段。
default: 参数默认值。
Returns:
ToolOptionBuilder: 当前构建器实例。
Raises:
ValueError: 当构建器已经设置完整 Schema 时抛出。
"""
if self.__parameters_schema_override is not None:
raise ValueError("已设置完整参数 Schema不能再逐项添加参数")
self.__params.append(
ToolParam(
name=name,
@@ -121,43 +447,83 @@ class ToolOptionBuilder:
description=description,
required=required,
enum_values=enum_values,
items_schema=deepcopy(items_schema),
properties=deepcopy(properties),
required_properties=list(required_properties or []),
additional_properties=deepcopy(additional_properties),
default=deepcopy(default),
)
)
return self
def build(self):
"""
构建工具调用项
:return: 工具调用项
"""
if self.__name == "" or self.__description == "":
raise ValueError("工具名称/描述不能为空")
def build(self) -> ToolOption:
"""构建工具定义。
Returns:
ToolOption: 构建完成的工具定义。
Raises:
ValueError: 当工具名称或描述缺失时抛出。
"""
if not self.__name or not self.__description:
raise ValueError("工具名称和描述不能为空")
return ToolOption(
name=self.__name,
description=self.__description,
params=None if len(self.__params) == 0 else self.__params,
params=None if not self.__params else list(self.__params),
parameters_schema_override=deepcopy(self.__parameters_schema_override),
)
class ToolCall:
"""
来自模型反馈的工具调用
"""
ToolDefinitionInput: TypeAlias = ToolOption | Dict[str, Any]
"""统一的工具定义输入类型。"""
def __init__(
self,
call_id: str,
func_name: str,
args: dict | None = None,
):
def normalize_tool_option(tool_definition: ToolDefinitionInput) -> ToolOption:
    """Turn any supported tool input into an internal ``ToolOption``.

    Args:
        tool_definition: A ``ToolOption`` or a raw definition dictionary.

    Returns:
        ToolOption: The input itself when it already is a ``ToolOption``;
        otherwise the result of ``ToolOption.from_definition``.
    """
    if not isinstance(tool_definition, ToolOption):
        return ToolOption.from_definition(tool_definition)
    return tool_definition
def normalize_tool_options(
    tool_definitions: List[ToolDefinitionInput] | None,
) -> List[ToolOption] | None:
    """Normalise a whole list of tool definitions.

    Args:
        tool_definitions: Raw tool definitions, or ``None``.

    Returns:
        List[ToolOption] | None: One ``ToolOption`` per input, in order;
        ``None`` when the input is ``None`` or empty.
    """
    if not tool_definitions:
        return None
    normalized: List[ToolOption] = []
    for raw_definition in tool_definitions:
        normalized.append(normalize_tool_option(raw_definition))
    return normalized
@dataclass(slots=True)
class ToolCall:
"""来自模型输出的工具调用。"""
call_id: str
func_name: str
args: Dict[str, Any] | None = None
def __post_init__(self) -> None:
"""执行工具调用的基础校验。
Raises:
ValueError: 当工具调用标识或函数名缺失时抛出。
"""
初始化工具调用
:param call_id: 工具调用ID
:param func_name: 要调用的函数名称
:param args: 工具调用参数
"""
self.call_id: str = call_id
self.func_name: str = func_name
self.args: dict | None = args
if not self.call_id:
raise ValueError("工具调用 ID 不能为空")
if not self.func_name:
raise ValueError("工具函数名称不能为空")

File diff suppressed because it is too large Load Diff