大修LLMReq
This commit is contained in:
@@ -81,10 +81,7 @@ class BaseClient:
|
||||
tuple[APIResponse, tuple[int, int, int]],
|
||||
]
|
||||
| None = None,
|
||||
async_response_parser: Callable[
|
||||
[ChatCompletion], tuple[APIResponse, tuple[int, int, int]]
|
||||
]
|
||||
| None = None,
|
||||
async_response_parser: Callable[[ChatCompletion], tuple[APIResponse, tuple[int, int, int]]] | None = None,
|
||||
interrupt_flag: asyncio.Event | None = None,
|
||||
) -> APIResponse:
|
||||
"""
|
||||
@@ -114,3 +111,37 @@ class BaseClient:
|
||||
:return: 嵌入响应
|
||||
"""
|
||||
raise RuntimeError("This method should be overridden in subclasses")
|
||||
|
||||
|
||||
class ClientRegistry:
|
||||
def __init__(self) -> None:
|
||||
self.client_registry: dict[str, type[BaseClient]] = {}
|
||||
|
||||
def register_client_class(self, client_type: str):
|
||||
"""
|
||||
注册API客户端类
|
||||
:param client_class: API客户端类
|
||||
"""
|
||||
|
||||
def decorator(cls: type[BaseClient]) -> type[BaseClient]:
|
||||
if not issubclass(cls, BaseClient):
|
||||
raise TypeError(f"{cls.__name__} is not a subclass of BaseClient")
|
||||
self.client_registry[client_type] = cls
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
def get_client_class(self, client_type: str) -> type[BaseClient]:
|
||||
"""
|
||||
获取注册的API客户端类
|
||||
Args:
|
||||
client_type: 客户端类型
|
||||
Returns:
|
||||
type[BaseClient]: 注册的API客户端类
|
||||
"""
|
||||
if client_type not in self.client_registry:
|
||||
raise KeyError(f"'{client_type}' 类型的 Client 未注册")
|
||||
return self.client_registry[client_type]
|
||||
|
||||
|
||||
client_registry = ClientRegistry()
|
||||
|
||||
Reference in New Issue
Block a user