diff --git a/pytests/config_test/test_config_manager_hot_reload.py b/pytests/config_test/test_config_manager_hot_reload.py index ab2dd898..a42a4133 100644 --- a/pytests/config_test/test_config_manager_hot_reload.py +++ b/pytests/config_test/test_config_manager_hot_reload.py @@ -16,7 +16,7 @@ async def test_handle_file_changes_throttles_reload(): called = 0 - async def reload_stub() -> bool: + async def reload_stub(changed_scopes=None) -> bool: nonlocal called called += 1 return True @@ -36,7 +36,7 @@ async def test_handle_file_changes_timeout_logged(caplog): manager._hot_reload_min_interval_s = 0.0 manager._hot_reload_timeout_s = 0.01 - async def reload_stub() -> bool: + async def reload_stub(changed_scopes=None) -> bool: await asyncio.sleep(0.05) return True @@ -55,7 +55,7 @@ async def test_handle_file_changes_empty_skips_reload(): called = 0 - async def reload_stub() -> bool: + async def reload_stub(changed_scopes=None) -> bool: nonlocal called called += 1 return True diff --git a/pytests/config_test/test_config_manager_startup_upgrade.py b/pytests/config_test/test_config_manager_startup_upgrade.py new file mode 100644 index 00000000..96d1f08e --- /dev/null +++ b/pytests/config_test/test_config_manager_startup_upgrade.py @@ -0,0 +1,33 @@ +from typing import Any + +import pytest + +from src.config import config as config_module +from src.config.config import Config, ConfigManager, ModelConfig + + +class _StartupUpgradeExit(Exception): + pass + + +def test_initialize_upgrades_bot_and_model_config_before_exit(monkeypatch): + manager = ConfigManager() + loaded_config_classes: list[type[Any]] = [] + exit_codes: list[int | None] = [] + + def fake_load_config_from_file(config_class, config_path, new_ver, override_repr=False): + loaded_config_classes.append(config_class) + return object(), True + + def fake_exit(code: int | None = None): + exit_codes.append(code) + raise _StartupUpgradeExit + + monkeypatch.setattr(config_module, "load_config_from_file", fake_load_config_from_file) + monkeypatch.setattr(config_module.sys, "exit", fake_exit) + + with pytest.raises(_StartupUpgradeExit): + manager.initialize() + + assert loaded_config_classes == [Config, ModelConfig] + assert exit_codes == [0] diff --git a/pytests/test_llm_provider_registry.py b/pytests/test_llm_provider_registry.py new file mode 100644 index 00000000..abc412ad --- /dev/null +++ b/pytests/test_llm_provider_registry.py @@ -0,0 +1,101 @@ +from typing import List + +from src.llm_models.model_client.base_client import ( + APIResponse, + AudioTranscriptionRequest, + BaseClient, + ClientProviderRegistration, + ClientRegistry, + EmbeddingRequest, + ResponseRequest, +) + + +class DummyClient(BaseClient): + """测试用 LLM 客户端。""" + + async def get_response(self, request: ResponseRequest) -> APIResponse: + """获取测试响应。 + + Args: + request: 统一响应请求。 + + Returns: + APIResponse: 测试响应。 + """ + del request + return APIResponse(content="ok") + + async def get_embedding(self, request: EmbeddingRequest) -> APIResponse: + """获取测试嵌入。 + + Args: + request: 统一嵌入请求。 + + Returns: + APIResponse: 测试嵌入响应。 + """ + del request + return APIResponse(embedding=[1.0]) + + async def get_audio_transcriptions(self, request: AudioTranscriptionRequest) -> APIResponse: + """获取测试音频转写。 + + Args: + request: 统一音频转写请求。 + + Returns: + APIResponse: 测试音频转写响应。 + """ + del request + return APIResponse(content="audio") + + def get_support_image_formats(self) -> List[str]: + """获取测试支持的图片格式。 + + Returns: + List[str]: 支持的图片格式列表。 + """ + return ["png"] + + +def test_client_registry_rejects_provider_conflict(): + """同一 client_type 被不同插件注册时应拒绝。""" + registry = ClientRegistry() + registry.replace_plugin_providers( + "plugin.alpha", + [ + ClientProviderRegistration( + client_type="example", + factory=DummyClient, + owner_plugin_id="plugin.alpha", + ) + ], + ) + + try: + registry.validate_plugin_provider_replacement("plugin.beta", ["example"]) + except ValueError as exc: + assert "冲突" in str(exc) + else: + raise AssertionError("不同插件注册相同 client_type 应失败") + + +def test_client_registry_unregisters_plugin_providers(): + """插件注销时应移除它拥有的 Provider 注册。""" + registry = ClientRegistry() + registry.replace_plugin_providers( + "plugin.alpha", + [ + ClientProviderRegistration( + client_type="example", + factory=DummyClient, + owner_plugin_id="plugin.alpha", + ) + ], + ) + + removed_count = registry.unregister_plugin_providers("plugin.alpha") + + assert removed_count == 1 + assert "example" not in registry.client_registry diff --git a/pytests/test_plugin_runtime.py b/pytests/test_plugin_runtime.py index d8561f19..095002e2 100644 --- a/pytests/test_plugin_runtime.py +++ b/pytests/test_plugin_runtime.py @@ -30,6 +30,7 @@ def build_test_manifest( name: str = "测试插件", description: str = "测试插件描述", dependencies: list[dict[str, str]] | None = None, + llm_providers: list[dict[str, str]] | None = None, capabilities: list[str] | None = None, host_min_version: str = "0.12.0", host_max_version: str = "1.0.0", @@ -44,6 +45,7 @@ def build_test_manifest( name: 展示名称。 description: 插件描述。 dependencies: 依赖声明列表。 + llm_providers: LLM Provider 静态声明列表。 capabilities: 能力声明列表。 host_min_version: Host 最低支持版本。 host_max_version: Host 最高支持版本。 @@ -75,6 +77,7 @@ def build_test_manifest( "max_version": sdk_max_version, }, "dependencies": dependencies or [], + "llm_providers": llm_providers or [], "capabilities": capabilities or [], "i18n": { "default_locale": "zh-CN", @@ -89,6 +92,7 @@ def build_test_manifest_model( *, version: str = "1.0.0", dependencies: list[dict[str, str]] | None = None, + llm_providers: list[dict[str, str]] | None = None, capabilities: list[str] | None = None, host_version: str = "1.0.0", sdk_version: str = "2.0.1", @@ -99,6 +103,7 @@ def build_test_manifest_model( plugin_id: 插件 ID。 version: 插件版本。 dependencies: 依赖声明列表。 + llm_providers: LLM Provider 静态声明列表。 capabilities: 能力声明列表。 host_version: 当前测试使用的 Host 版本。 sdk_version: 当前测试使用的 SDK 版本。 @@ -114,6 +119,7 @@ def build_test_manifest_model( plugin_id, version=version, dependencies=dependencies, + llm_providers=llm_providers, capabilities=capabilities, ) ) @@ -1055,6 +1061,63 @@ class TestManifestValidator: assert validator.validate(manifest) is False assert any("Python 包依赖冲突" in error for error in validator.errors) + def test_llm_provider_manifest_declaration(self): + from src.plugin_runtime.runner.manifest_validator import ManifestValidator + + validator = ManifestValidator(host_version="1.0.0", sdk_version="2.0.1") + manifest = build_test_manifest( + "test.llm-provider", + llm_providers=[ + { + "client_type": "example.provider", + "name": "Example Provider", + "description": "测试 Provider", + "version": "1.0.0", + } + ], + ) + + parsed_manifest = validator.parse_manifest(manifest) + + assert parsed_manifest is not None + assert parsed_manifest.llm_provider_client_types == ["example.provider"] + + def test_duplicate_llm_provider_manifest_declaration_is_rejected(self): + from src.plugin_runtime.runner.manifest_validator import ManifestValidator + + validator = ManifestValidator(host_version="1.0.0", sdk_version="2.0.1") + manifest = build_test_manifest( + "test.llm-provider-duplicate", + llm_providers=[ + {"client_type": "example.provider"}, + {"client_type": "example.provider"}, + ], + ) + + assert validator.validate(manifest) is False + assert any("重复的 LLM Provider" in error for error in validator.errors) + + +def test_llm_provider_conflict_blocks_all_conflicting_plugins(tmp_path: Path): + from src.plugin_runtime.integration import PluginRuntimeManager + + plugin_root = tmp_path / "plugins" + plugin_root.mkdir() + for plugin_id in ["test.provider-alpha", "test.provider-beta"]: + plugin_dir = plugin_root / plugin_id + plugin_dir.mkdir() + manifest = build_test_manifest( + plugin_id, + llm_providers=[{"client_type": "example.provider"}], + ) + (plugin_dir / "_manifest.json").write_text(json.dumps(manifest), encoding="utf-8") + (plugin_dir / "plugin.py").write_text("def create_plugin():\n return None\n", encoding="utf-8") + + blocked_reasons = PluginRuntimeManager._discover_llm_provider_conflicts([plugin_root]) + + assert set(blocked_reasons) == {"test.provider-alpha", "test.provider-beta"} + assert all("example.provider" in reason for reason in blocked_reasons.values()) + class TestVersionComparator: """版本号比较器测试""" diff --git a/src/config/config.py b/src/config/config.py index 7a9a6d3a..36eac1a9 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -196,8 +196,15 @@ class ConfigManager: def initialize(self): logger.info(t("config.current_version", version=MMC_VERSION)) logger.info(t("config.loading")) - self.global_config = self.load_global_config() - self.model_config = self.load_model_config() + self.global_config, global_updated = load_config_from_file(Config, self.bot_config_path, CONFIG_VERSION) + self.model_config, model_updated = load_config_from_file( + ModelConfig, + self.model_config_path, + MODEL_CONFIG_VERSION, + True, + ) + if global_updated or model_updated: + sys.exit(0) # 配置已自动升级,退出一次让用户确认新配置后再启动 logger.info(t("config.loaded")) def load_global_config(self) -> Config: diff --git a/src/llm_models/model_client/base_client.py b/src/llm_models/model_client/base_client.py index 6e16e759..1c4c6b29 100644 --- a/src/llm_models/model_client/base_client.py +++ b/src/llm_models/model_client/base_client.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Any, Callable, Coroutine, Dict, List, Tuple, Type +from typing import Any, Callable, Coroutine, Dict, List, Set, Tuple, Type import asyncio @@ -204,19 +204,48 @@ class BaseClient(ABC): raise NotImplementedError("'get_support_image_formats' method should be overridden in subclasses") +ClientFactory = Callable[[APIProvider], BaseClient] +"""根据 APIProvider 创建客户端实例的工厂函数。""" + + +@dataclass(slots=True) +class ClientProviderRegistration: + """LLM Provider 客户端类型注册信息。""" + + client_type: str + """客户端类型标识,对应模型配置中的 `api_providers[].client_type`。""" + + factory: ClientFactory + """客户端实例工厂。""" + + owner_plugin_id: str | None = None + """拥有该客户端类型的插件 ID;主程序内置类型为 ``None``。""" + + version: str = "1.0.0" + """Provider 实现版本。""" + + description: str = "" + """Provider 描述文本。""" + + builtin: bool = False + """是否为主程序内置 Provider。""" + + class ClientRegistry: """客户端注册表。""" def __init__(self) -> None: """初始化注册表并绑定配置重载回调。""" - self.client_registry: Dict[str, Type[BaseClient]] = {} - """APIProvider.type -> BaseClient的映射表""" + self.client_registry: Dict[str, ClientProviderRegistration] = {} + """APIProvider.client_type -> Provider 注册信息映射表。""" self.client_instance_cache: Dict[str, BaseClient] = {} """APIProvider.name -> BaseClient的映射表""" + self._owner_client_types: Dict[str, Set[str]] = {} + """插件 ID -> 该插件拥有的 client_type 集合。""" config_manager.register_reload_callback(self.clear_client_instance_cache) def register_client_class(self, client_type: str) -> Callable[[Type[BaseClient]], Type[BaseClient]]: - """注册 API 客户端类。 + """注册主程序内置 API 客户端类。 Args: client_type: 客户端类型标识。 @@ -226,13 +255,180 @@ class ClientRegistry: """ def decorator(cls: Type[BaseClient]) -> Type[BaseClient]: + """将内置客户端类注册到全局客户端注册表。 + + Args: + cls: 待注册的客户端类。 + + Returns: + Type[BaseClient]: 原始客户端类。 + """ if not issubclass(cls, BaseClient): raise TypeError(f"{cls.__name__} is not a subclass of BaseClient") - self.client_registry[client_type] = cls + self.register_provider( + ClientProviderRegistration( + client_type=client_type, + factory=cls, + builtin=True, + ) + ) return cls return decorator + @staticmethod + def _normalize_client_type(client_type: str) -> str: + """规范化客户端类型标识。 + + Args: + client_type: 原始客户端类型标识。 + + Returns: + str: 去除首尾空白后的客户端类型标识。 + + Raises: + ValueError: 当客户端类型为空时抛出。 + """ + normalized_client_type = str(client_type or "").strip() + if not normalized_client_type: + raise ValueError("client_type 不能为空") + return normalized_client_type + + def register_provider(self, registration: ClientProviderRegistration) -> None: + """注册单个客户端类型。 + + Args: + registration: Provider 注册信息。 + + Raises: + ValueError: 当客户端类型冲突时抛出。 + """ + client_type = self._normalize_client_type(registration.client_type) + existing = self.client_registry.get(client_type) + if existing is not None and existing.owner_plugin_id != registration.owner_plugin_id: + raise ValueError( + f"LLM Provider client_type 冲突: {client_type} 已由 {existing.owner_plugin_id or 'host'} 注册" + ) + + self.client_registry[client_type] = ClientProviderRegistration( + client_type=client_type, + factory=registration.factory, + owner_plugin_id=registration.owner_plugin_id, + version=registration.version, + description=registration.description, + builtin=registration.builtin, + ) + if registration.owner_plugin_id: + self._owner_client_types.setdefault(registration.owner_plugin_id, set()).add(client_type) + self.clear_client_instance_cache_by_client_type(client_type) + + def validate_plugin_provider_replacement(self, plugin_id: str, client_types: List[str]) -> None: + """校验插件 Provider 替换是否会造成运行时冲突。 + + Args: + plugin_id: 目标插件 ID。 + client_types: 插件即将注册的客户端类型列表。 + + Raises: + ValueError: 当客户端类型为空、重复或与其他 owner 冲突时抛出。 + """ + normalized_plugin_id = str(plugin_id or "").strip() + if not normalized_plugin_id: + raise ValueError("plugin_id 不能为空") + + normalized_client_types = [self._normalize_client_type(client_type) for client_type in client_types] + duplicate_client_types = sorted( + { + client_type + for client_type in normalized_client_types + if normalized_client_types.count(client_type) > 1 + } + ) + if duplicate_client_types: + raise ValueError(f"插件 {normalized_plugin_id} 重复声明 LLM Provider: {', '.join(duplicate_client_types)}") + + for client_type in normalized_client_types: + existing = self.client_registry.get(client_type) + if existing is None or existing.owner_plugin_id == normalized_plugin_id: + continue + raise ValueError( + f"LLM Provider client_type 冲突: {client_type} 已由 {existing.owner_plugin_id or 'host'} 注册" + ) + + def replace_plugin_providers( + self, + plugin_id: str, + registrations: List[ClientProviderRegistration], + ) -> None: + """原子替换一个插件拥有的全部 Provider 注册。 + + Args: + plugin_id: 目标插件 ID。 + registrations: 插件当前上报的 Provider 注册列表。 + + Raises: + ValueError: 当注册信息不合法或存在冲突时抛出。 + """ + normalized_plugin_id = str(plugin_id or "").strip() + self.validate_plugin_provider_replacement( + normalized_plugin_id, + [registration.client_type for registration in registrations], + ) + self.unregister_plugin_providers(normalized_plugin_id) + for registration in registrations: + self.register_provider( + ClientProviderRegistration( + client_type=registration.client_type, + factory=registration.factory, + owner_plugin_id=normalized_plugin_id, + version=registration.version, + description=registration.description, + builtin=False, + ) + ) + + def unregister_plugin_providers(self, plugin_id: str) -> int: + """注销一个插件拥有的全部 Provider 注册。 + + Args: + plugin_id: 目标插件 ID。 + + Returns: + int: 被注销的客户端类型数量。 + """ + normalized_plugin_id = str(plugin_id or "").strip() + if not normalized_plugin_id: + return 0 + + client_types = self._owner_client_types.pop(normalized_plugin_id, set()) + removed_count = 0 + for client_type in client_types: + registration = self.client_registry.get(client_type) + if registration is None or registration.owner_plugin_id != normalized_plugin_id: + continue + self.client_registry.pop(client_type, None) + self.clear_client_instance_cache_by_client_type(client_type) + removed_count += 1 + return removed_count + + def clear_client_instance_cache_by_client_type(self, client_type: str) -> None: + """清理指定客户端类型对应的客户端实例缓存。 + + Args: + client_type: 需要清理缓存的客户端类型。 + """ + normalized_client_type = str(client_type or "").strip() + if not normalized_client_type: + return + + stale_provider_names = [ + provider_name + for provider_name, client in self.client_instance_cache.items() + if client.api_provider.client_type == normalized_client_type + ] + for provider_name in stale_provider_names: + self.client_instance_cache.pop(provider_name, None) + def get_client_class_instance(self, api_provider: APIProvider, force_new: bool = False) -> BaseClient: """获取注册的 API 客户端实例。 @@ -249,15 +445,14 @@ class ClientRegistry: # 如果强制创建新实例,直接创建不使用缓存 if force_new: - if client_class := self.client_registry.get(api_provider.client_type): - return client_class(api_provider) - else: - raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册") + if registration := self.client_registry.get(api_provider.client_type): + return registration.factory(api_provider) + raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册") # 正常的缓存逻辑 if api_provider.name not in self.client_instance_cache: - if client_class := self.client_registry.get(api_provider.client_type): - self.client_instance_cache[api_provider.name] = client_class(api_provider) + if registration := self.client_registry.get(api_provider.client_type): + self.client_instance_cache[api_provider.name] = registration.factory(api_provider) else: raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册") return self.client_instance_cache[api_provider.name] diff --git a/src/llm_models/model_client/plugin_client.py b/src/llm_models/model_client/plugin_client.py new file mode 100644 index 00000000..9758534f --- /dev/null +++ b/src/llm_models/model_client/plugin_client.py @@ -0,0 +1,191 @@ +from typing import Any, Dict, List + +from src.config.model_configs import APIProvider +from src.llm_models.exceptions import RespParseException +from src.llm_models.model_client.base_client import ( + APIResponse, + AudioTranscriptionRequest, + BaseClient, + EmbeddingRequest, + ResponseRequest, + UsageRecord, +) +from src.llm_models.request_snapshot import ( + deserialize_tool_calls_snapshot, + serialize_audio_request_snapshot, + serialize_embedding_request_snapshot, + serialize_response_request_snapshot, +) + + +class PluginLLMClient(BaseClient): + """通过插件 Runner RPC 调用第三方 LLM Provider 的客户端代理。""" + + def __init__( + self, + api_provider: APIProvider, + supervisor: Any, + plugin_id: str, + client_type: str, + ) -> None: + """初始化插件 LLM Provider 客户端代理。 + + Args: + api_provider: 当前 API Provider 配置。 + supervisor: 拥有目标插件的 Supervisor。 + plugin_id: 目标插件 ID。 + client_type: 目标客户端类型。 + """ + super().__init__(api_provider) + self._supervisor = supervisor + self._plugin_id = plugin_id + self._client_type = client_type + + async def get_response(self, request: ResponseRequest) -> APIResponse: + """获取对话响应。 + + Args: + request: 统一响应请求对象。 + + Returns: + APIResponse: 统一响应对象。 + + Raises: + RespParseException: 插件返回内容无法转换为统一响应时抛出。 + """ + if request.stream_response_handler is not None or request.async_response_parser is not None: + raise RespParseException(message="插件 LLM Provider 暂不支持 Host 侧自定义流式处理器或响应解析器") + payload = serialize_response_request_snapshot(request) + result = await self._invoke_provider("response", payload) + return self._build_api_response(result, request.model_info.name, request.model_info.api_provider) + + async def get_embedding(self, request: EmbeddingRequest) -> APIResponse: + """获取文本嵌入。 + + Args: + request: 统一嵌入请求对象。 + + Returns: + APIResponse: 嵌入响应。 + """ + result = await self._invoke_provider("embedding", serialize_embedding_request_snapshot(request)) + return self._build_api_response(result, request.model_info.name, request.model_info.api_provider) + + async def get_audio_transcriptions(self, request: AudioTranscriptionRequest) -> APIResponse: + """获取音频转录。 + + Args: + request: 统一音频转录请求对象。 + + Returns: + APIResponse: 音频转录响应。 + """ + result = await self._invoke_provider("audio_transcription", serialize_audio_request_snapshot(request)) + return self._build_api_response(result, request.model_info.name, request.model_info.api_provider) + + def get_support_image_formats(self) -> List[str]: + """获取支持的图片格式。 + + Returns: + List[str]: 插件 Provider 默认接收的图片格式列表。 + """ + return ["jpeg", "jpg", "png", "webp"] + + def _build_api_provider_snapshot(self) -> Dict[str, Any]: + """构建可传给插件 Provider 的 API Provider 配置快照。 + + Returns: + Dict[str, Any]: 包含认证信息的 API Provider 配置字典。 + """ + return self.api_provider.model_dump(mode="json") + + async def _invoke_provider(self, operation: str, request_payload: Dict[str, Any]) -> Dict[str, Any]: + """调用插件 Provider。 + + Args: + operation: 请求操作类型。 + request_payload: 已序列化的内部请求。 + + Returns: + Dict[str, Any]: 插件返回的响应字典。 + + Raises: + RespParseException: 插件调用失败或返回格式不合法时抛出。 + """ + try: + response = await self._supervisor.invoke_llm_provider( + plugin_id=self._plugin_id, + client_type=self._client_type, + operation=operation, + request={ + **request_payload, + "api_provider": self._build_api_provider_snapshot(), + }, + timeout_ms=max(1000, int(self.api_provider.timeout) * 1000), + ) + except Exception as exc: + raise RespParseException(message=f"插件 LLM Provider RPC 调用失败: {exc}") from exc + if response.error: + raise RespParseException(message=str(response.error.get("message", "插件 LLM Provider 调用失败"))) + + payload = response.payload if isinstance(response.payload, dict) else {} + success = bool(payload.get("success", False)) + result = payload.get("result") + if not success: + raise RespParseException(message=str(result or "插件 LLM Provider 返回失败")) + if not isinstance(result, dict): + raise RespParseException(message="插件 LLM Provider 返回值必须是字典") + return result + + @staticmethod + def _build_usage_record(raw_usage: Any, model_name: str, provider_name: str) -> UsageRecord | None: + """从插件返回值恢复使用量记录。 + + Args: + raw_usage: 插件返回的使用量字段。 + model_name: 当前模型名称。 + provider_name: 当前 Provider 名称。 + + Returns: + UsageRecord | None: 可挂载到 APIResponse 的使用量记录;缺失时返回 ``None``。 + """ + if not isinstance(raw_usage, dict): + return None + return UsageRecord( + model_name=str(raw_usage.get("model_name") or model_name), + provider_name=str(raw_usage.get("provider_name") or provider_name), + prompt_tokens=int(raw_usage.get("prompt_tokens") or 0), + completion_tokens=int(raw_usage.get("completion_tokens") or 0), + total_tokens=int(raw_usage.get("total_tokens") or 0), + prompt_cache_hit_tokens=int(raw_usage.get("prompt_cache_hit_tokens") or 0), + prompt_cache_miss_tokens=int(raw_usage.get("prompt_cache_miss_tokens") or 0), + ) + + @staticmethod + def _build_api_response(result: Dict[str, Any], model_name: str, provider_name: str) -> APIResponse: + """从插件返回值恢复统一 APIResponse。 + + Args: + result: 插件返回的响应字典。 + model_name: 当前模型名称。 + provider_name: 当前 Provider 名称。 + + Returns: + APIResponse: 统一响应对象。 + """ + raw_embedding = result.get("embedding") + embedding = [float(item) for item in raw_embedding] if isinstance(raw_embedding, list) else None + content = result.get("content") + if not isinstance(content, str): + content = result.get("response") + reasoning_content = result.get("reasoning_content") + if not isinstance(reasoning_content, str): + reasoning_content = result.get("reasoning") + return APIResponse( + content=content if isinstance(content, str) else None, + reasoning_content=reasoning_content if isinstance(reasoning_content, str) else None, + tool_calls=deserialize_tool_calls_snapshot(result.get("tool_calls")) or None, + embedding=embedding, + usage=PluginLLMClient._build_usage_record(result.get("usage"), model_name, provider_name), + raw_data=result.get("raw_data", result), + ) diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index dbd45e8e..cb381c82 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -1072,3 +1072,65 @@ class TempMethodsLLMUtils: if provider.name == provider_name: return provider raise ValueError(f"未找到名为 '{provider_name}' 的API提供商") + + +class LLMRequest(LLMOrchestrator): + """兼容旧调用方的 LLM 请求入口。 + + 新代码应优先使用 ``LLMOrchestrator`` 或服务层 ``LLMServiceClient``; + 该类保留旧版 ``model_set=TaskConfig`` 的构造方式,并在运行时解析 + 对应的最新任务配置名称。 + """ + + def __init__(self, model_set: TaskConfig, request_type: str = "") -> None: + """初始化旧版 LLM 请求对象。 + + Args: + model_set: 旧调用方传入的任务配置对象。 + request_type: 当前请求的业务类型标识。 + """ + self._task_config_name = self._resolve_task_config_name(model_set) + super().__init__(task_name=self._task_config_name, request_type=request_type) + + @staticmethod + def _build_task_config_signature(task_config: TaskConfig) -> Tuple[Any, ...]: + """构造任务配置签名。 + + Args: + task_config: 任务配置对象。 + + Returns: + Tuple[Any, ...]: 可用于匹配热重载前后任务配置的签名。 + """ + return ( + tuple(task_config.model_list), + task_config.max_tokens, + task_config.temperature, + task_config.slow_threshold, + task_config.selection_strategy, + ) + + @classmethod + def _resolve_task_config_name(cls, model_set: TaskConfig) -> str: + """根据旧版 TaskConfig 对象解析任务配置名称。 + + Args: + model_set: 旧调用方传入的任务配置对象。 + + Returns: + str: 对应 ``model_task_config`` 下的字段名。 + + Raises: + ValueError: 未能找到匹配任务配置时抛出。 + """ + target_signature = cls._build_task_config_signature(model_set) + model_task_config = config_manager.get_model_config().model_task_config + for attr_name in dir(model_task_config): + if attr_name.startswith("_"): + continue + attr_value = getattr(model_task_config, attr_name) + if not isinstance(attr_value, TaskConfig): + continue + if attr_value is model_set or cls._build_task_config_signature(attr_value) == target_signature: + return attr_name + raise ValueError("无法根据旧版 model_set 解析任务配置名称") diff --git a/src/plugin_runtime/host/supervisor.py b/src/plugin_runtime/host/supervisor.py index 4927a120..3e56944d 100644 --- a/src/plugin_runtime/host/supervisor.py +++ b/src/plugin_runtime/host/supervisor.py @@ -10,6 +10,8 @@ import sys from src.common.logger import get_logger from src.config.config import config_manager, global_config +from src.llm_models.model_client.base_client import ClientProviderRegistration, client_registry +from src.llm_models.model_client.plugin_client import PluginLLMClient from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, get_platform_io_manager from src.platform_io.drivers import PluginPlatformDriver from src.platform_io.route_key_factory import RouteKeyFactory @@ -30,6 +32,7 @@ from src.plugin_runtime.protocol.envelope import ( HealthPayload, InspectPluginConfigPayload, InspectPluginConfigResultPayload, + LLMProviderInvokePayload, MessageGatewayStateUpdatePayload, MessageGatewayStateUpdateResultPayload, PROTOCOL_VERSION, @@ -417,6 +420,38 @@ class PluginRunnerSupervisor: timeout_ms=timeout_ms, ) + async def invoke_llm_provider( + self, + plugin_id: str, + client_type: str, + operation: str, + request: Optional[Dict[str, Any]] = None, + timeout_ms: int = 30000, + ) -> Envelope: + """调用插件声明的 LLM Provider。 + + Args: + plugin_id: 目标插件 ID。 + client_type: 目标客户端类型。 + operation: 请求操作类型。 + request: 已序列化的 LLM 请求。 + timeout_ms: RPC 超时时间,单位毫秒。 + + Returns: + Envelope: Runner 返回的响应信封。 + """ + payload = LLMProviderInvokePayload( + client_type=client_type, + operation=operation, + request=request or {}, + ) + return await self._rpc_server.send_request( + "plugin.invoke_llm_provider", + plugin_id=plugin_id, + payload=payload.model_dump(), + timeout_ms=timeout_ms, + ) + async def invoke_api( self, plugin_id: str, @@ -779,6 +814,22 @@ class PluginRunnerSupervisor: component_declarations = [component.model_dump() for component in payload.components] runtime_components, api_components = self._split_component_declarations(component_declarations) + try: + client_registry.validate_plugin_provider_replacement( + payload.plugin_id, + [provider.client_type for provider in payload.llm_providers], + ) + except Exception as exc: + logger.error(f"插件 {payload.plugin_id} LLM Provider 注册校验失败: {exc}") + return envelope.make_error_response( + ErrorCode.E_BAD_PAYLOAD.value, + str(exc), + details={ + "plugin_id": payload.plugin_id, + "llm_provider_count": len(payload.llm_providers), + }, + ) + try: registered_count = self._component_registry.register_plugin_components( payload.plugin_id, @@ -798,6 +849,24 @@ class PluginRunnerSupervisor: self._api_registry.remove_apis_by_plugin(payload.plugin_id) registered_api_count = self._api_registry.register_plugin_apis(payload.plugin_id, api_components) await self._unregister_all_message_gateway_drivers_for_plugin(payload.plugin_id) + client_registry.replace_plugin_providers( + payload.plugin_id, + [ + ClientProviderRegistration( + client_type=provider.client_type, + factory=lambda api_provider, provider_client_type=provider.client_type: PluginLLMClient( + api_provider=api_provider, + supervisor=self, + plugin_id=payload.plugin_id, + client_type=provider_client_type, + ), + owner_plugin_id=payload.plugin_id, + version=provider.version, + description=provider.description or provider.name, + ) + for provider in payload.llm_providers + ], + ) self._registered_plugins[payload.plugin_id] = payload self._message_gateway_states[payload.plugin_id] = {} @@ -810,6 +879,7 @@ class PluginRunnerSupervisor: "message_gateways": len( self._component_registry.get_message_gateways(plugin_id=payload.plugin_id, enabled_only=False) ), + "llm_providers": len(payload.llm_providers), } ) @@ -829,6 +899,7 @@ class PluginRunnerSupervisor: removed_components = self._component_registry.remove_components_by_plugin(payload.plugin_id) removed_apis = self._api_registry.remove_apis_by_plugin(payload.plugin_id) + removed_llm_providers = client_registry.unregister_plugin_providers(payload.plugin_id) self._authorization.revoke_permission_token(payload.plugin_id) removed_registration = self._registered_plugins.pop(payload.plugin_id, None) is not None await self._unregister_all_message_gateway_drivers_for_plugin(payload.plugin_id) @@ -841,6 +912,7 @@ class PluginRunnerSupervisor: "reason": payload.reason, "removed_components": removed_components, "removed_apis": removed_apis, + "removed_llm_providers": removed_llm_providers, "removed_registration": removed_registration, } ) @@ -1505,6 +1577,8 @@ class PluginRunnerSupervisor: def _clear_runner_state(self) -> None: """清理当前 Runner 对应的 Host 侧注册状态。""" + for plugin_id in list(self._registered_plugins): + client_registry.unregister_plugin_providers(plugin_id) self._authorization.clear() self._api_registry.clear() self._component_registry.clear() diff --git a/src/plugin_runtime/integration.py b/src/plugin_runtime/integration.py index 1cb796fa..548136ae 100644 --- a/src/plugin_runtime/integration.py +++ b/src/plugin_runtime/integration.py @@ -148,6 +148,35 @@ class PluginRuntimeManager( validator = ManifestValidator(validate_python_package_dependencies=False) return validator.build_plugin_dependency_map(plugin_dirs) + @classmethod + def _discover_llm_provider_conflicts(cls, plugin_dirs: Iterable[Path]) -> Dict[str, str]: + """扫描插件 Manifest,发现 LLM Provider client_type 冲突。 + + Args: + plugin_dirs: 需要扫描的插件根目录集合。 + + Returns: + Dict[str, str]: 需要阻止加载的插件 ID 与原因映射。 + """ + validator = ManifestValidator(validate_python_package_dependencies=False) + provider_owners: Dict[str, List[str]] = {} + for _plugin_path, manifest in validator.iter_plugin_manifests(plugin_dirs, require_entrypoint=True): + for client_type in manifest.llm_provider_client_types: + provider_owners.setdefault(client_type, []).append(manifest.id) + + blocked_reasons: Dict[str, str] = {} + for client_type, plugin_ids in provider_owners.items(): + unique_plugin_ids = sorted(set(plugin_ids)) + if len(unique_plugin_ids) <= 1: + continue + reason = ( + f"LLM Provider client_type 冲突: {client_type} 被以下插件重复声明: " + f"{', '.join(unique_plugin_ids)}" + ) + for plugin_id in unique_plugin_ids: + blocked_reasons[plugin_id] = reason + return blocked_reasons + @classmethod def _build_group_start_order( cls, @@ -271,7 +300,11 @@ class PluginRuntimeManager( """ result = await self._plugin_dependency_pipeline.execute(plugin_dirs) - changed_plugin_ids = self._set_blocked_plugin_reasons(result.blocked_plugin_reasons) + blocked_plugin_reasons = { + **result.blocked_plugin_reasons, + **self._discover_llm_provider_conflicts(plugin_dirs), + } + changed_plugin_ids = self._set_blocked_plugin_reasons(blocked_plugin_reasons) return DependencySyncState( blocked_changed_plugin_ids=changed_plugin_ids, environment_changed=result.environment_changed, diff --git a/src/plugin_runtime/protocol/envelope.py b/src/plugin_runtime/protocol/envelope.py index 81d8ec33..58d6d73b 100644 --- a/src/plugin_runtime/protocol/envelope.py +++ b/src/plugin_runtime/protocol/envelope.py @@ -199,6 +199,21 @@ class ComponentDeclaration(BaseModel): """组件元数据""" +class LLMProviderDeclaration(BaseModel): + """单个 LLM Provider 声明。""" + + client_type: str = Field(description="客户端类型标识,对应模型配置中的 api_providers[].client_type") + """客户端类型标识。""" + name: str = Field(default="", description="Provider 展示名称") + """Provider 展示名称。""" + description: str = Field(default="", description="Provider 描述") + """Provider 描述。""" + version: str = Field(default="1.0.0", description="Provider 实现版本") + """Provider 实现版本。""" + metadata: Dict[str, Any] = Field(default_factory=dict, description="Provider 元数据") + """Provider 元数据。""" + + class RegisterPluginPayload(BaseModel): """插件组件注册请求载荷。 @@ -212,6 +227,8 @@ class RegisterPluginPayload(BaseModel): """插件版本""" components: List[ComponentDeclaration] = Field(default_factory=list, description="组件列表") """组件列表""" + llm_providers: List[LLMProviderDeclaration] = Field(default_factory=list, description="LLM Provider 声明列表") + """LLM Provider 声明列表。""" capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表") """所需能力列表""" dependencies: List[str] = Field(default_factory=list, description="插件级依赖插件 ID 列表") @@ -254,6 +271,17 @@ class InvokeResultPayload(BaseModel): """返回值""" +class LLMProviderInvokePayload(BaseModel): + """plugin.invoke_llm_provider 请求 payload。""" + + client_type: str = Field(description="目标 LLM Provider 客户端类型") + """目标 LLM Provider 客户端类型。""" + operation: str = Field(description="请求操作类型") + """请求操作类型,如 response、embedding、audio_transcription。""" + request: Dict[str, Any] = Field(default_factory=dict, description="已序列化的 LLM 请求") + """已序列化的 LLM 请求。""" + + # ====== 能力调用消息 ====== class CapabilityRequestPayload(BaseModel): """cap.* 请求 payload(插件 -> Host 能力调用)""" diff --git a/src/plugin_runtime/runner/manifest_validator.py b/src/plugin_runtime/runner/manifest_validator.py index 92e97d14..92f0b315 100644 --- a/src/plugin_runtime/runner/manifest_validator.py +++ b/src/plugin_runtime/runner/manifest_validator.py @@ -7,7 +7,7 @@ from functools import lru_cache from importlib import metadata as importlib_metadata from pathlib import Path -from typing import Annotated, Any, Dict, Iterable, List, Literal, Optional, Tuple, Union +from typing import Annotated, Any, Dict, Iterable, List, Literal, Optional, Set, Tuple, Union import json import re @@ -453,6 +453,38 @@ class PythonPackageDependencyDefinition(_StrictManifestModel): return value +class LLMProviderManifestDeclaration(_StrictManifestModel): + """插件 Manifest 中声明的 LLM Provider。""" + + client_type: str = Field(description="客户端类型标识,对应模型配置中的 api_providers[].client_type") + """客户端类型标识。""" + name: str = Field(default="", description="Provider 展示名称") + """Provider 展示名称。""" + description: str = Field(default="", description="Provider 描述") + """Provider 描述。""" + version: str = Field(default="1.0.0", description="Provider 实现版本") + """Provider 实现版本。""" + + @field_validator("client_type") + @classmethod + def _validate_client_type(cls, value: str) -> str: + """校验客户端类型标识。 + + Args: + value: 原始客户端类型标识。 + + Returns: + str: 合法的客户端类型标识。 + + Raises: + ValueError: 当客户端类型为空时抛出。 + """ + normalized_value = str(value or "").strip() + if not normalized_value: + raise ValueError("client_type 不能为空") + return normalized_value + + ManifestDependencyDefinition = Annotated[ Union[PluginDependencyDefinition, PythonPackageDependencyDefinition], Field(discriminator="type"), @@ -472,6 +504,10 @@ class PluginManifest(_StrictManifestModel): host_application: ManifestVersionRange = Field(description="Host 兼容区间") sdk: ManifestVersionRange = Field(description="SDK 兼容区间") dependencies: List[ManifestDependencyDefinition] = Field(default_factory=list, description="依赖声明") + llm_providers: List[LLMProviderManifestDeclaration] = Field( + default_factory=list, + description="插件静态声明的 LLM Provider 列表", + ) capabilities: List[str] = Field(description="插件声明的能力请求") i18n: ManifestI18n = Field(description="国际化配置") id: str = Field(description="稳定插件 ID") @@ -567,6 +603,23 @@ class PluginManifest(_StrictManifestModel): return self + @model_validator(mode="after") + def _validate_llm_providers(self) -> "PluginManifest": + """校验 LLM Provider 静态声明集合。 + + Returns: + PluginManifest: 当前对象本身。 + + Raises: + ValueError: 当同一 Manifest 内重复声明 client_type 时抛出。 + """ + client_types: Set[str] = set() + for provider in self.llm_providers: + if provider.client_type in client_types: + raise ValueError(f"存在重复的 LLM Provider 声明: {provider.client_type}") + client_types.add(provider.client_type) + return self + @property def plugin_dependencies(self) -> List[PluginDependencyDefinition]: """返回插件级依赖列表。 @@ -598,6 +651,15 @@ class PluginManifest(_StrictManifestModel): """ return [dependency.id for dependency in self.plugin_dependencies] + @property + def llm_provider_client_types(self) -> List[str]: + """返回 Manifest 静态声明的 LLM Provider client_type 列表。 + + Returns: + List[str]: 当前插件声明的 LLM Provider client_type。 + """ + return [provider.client_type for provider in self.llm_providers] + class ManifestValidator: """严格的插件 Manifest v2 校验器。""" diff --git a/src/plugin_runtime/runner/plugin_loader.py b/src/plugin_runtime/runner/plugin_loader.py index 31a7f19b..95e6b814 100644 --- a/src/plugin_runtime/runner/plugin_loader.py +++ b/src/plugin_runtime/runner/plugin_loader.py @@ -54,6 +54,7 @@ class PluginMeta: self.capabilities_required = list(manifest.capabilities) self.dependencies: List[str] = list(manifest.plugin_dependency_ids) self.component_handlers: Dict[str, str] = {} + self.llm_provider_handlers: Dict[str, str] = {} class PluginLoader: diff --git a/src/plugin_runtime/runner/runner_main.py b/src/plugin_runtime/runner/runner_main.py index bb8c3f4c..a2e5f460 100644 --- a/src/plugin_runtime/runner/runner_main.py +++ b/src/plugin_runtime/runner/runner_main.py @@ -48,6 +48,8 @@ from src.plugin_runtime.protocol.envelope import ( InspectPluginConfigResultPayload, InvokePayload, InvokeResultPayload, + LLMProviderDeclaration, + LLMProviderInvokePayload, RegisterPluginPayload, ReloadPluginPayload, ReloadPluginResultPayload, @@ -891,6 +893,7 @@ class PluginRunner: self._rpc_client.register_method("plugin.invoke_api", self._handle_invoke) self._rpc_client.register_method("plugin.invoke_tool", self._handle_invoke) self._rpc_client.register_method("plugin.invoke_message_gateway", self._handle_invoke) + self._rpc_client.register_method("plugin.invoke_llm_provider", self._handle_llm_provider_invoke) self._rpc_client.register_method("plugin.emit_event", self._handle_event_invoke) self._rpc_client.register_method("plugin.invoke_hook", self._handle_hook_invoke) self._rpc_client.register_method("plugin.health", self._handle_health) @@ -980,6 +983,7 @@ class PluginRunner: """ # 收集插件组件声明 components: List[ComponentDeclaration] = [] + llm_providers: List[LLMProviderDeclaration] = [] config_reload_subscriptions: List[str] = [] instance = meta.instance @@ -1016,11 +1020,43 @@ class PluginRunner: ) if hasattr(instance, "get_config_reload_subscriptions"): config_reload_subscriptions = list(instance.get_config_reload_subscriptions()) + if hasattr(instance, "get_llm_providers"): + meta.llm_provider_handlers.clear() + for provider_info in instance.get_llm_providers(): + if not isinstance(provider_info, dict): + continue + + client_type = str(provider_info.get("client_type", "") or "").strip() + raw_metadata = provider_info.get("metadata", {}) + provider_metadata = raw_metadata if isinstance(raw_metadata, dict) else {} + if client_type: + handler_name = str(provider_metadata.get("handler_name", client_type) or client_type).strip() + meta.llm_provider_handlers[client_type] = handler_name or client_type + + llm_providers.append( + LLMProviderDeclaration( + client_type=client_type, + name=str(provider_info.get("name", "") or "").strip(), + description=str(provider_info.get("description", "") or "").strip(), + version=str(provider_info.get("version", "1.0.0") or "1.0.0").strip() or "1.0.0", + metadata=provider_metadata, + ) + ) + + declared_client_types = sorted(meta.manifest.llm_provider_client_types) + registered_client_types = sorted(provider.client_type for provider in llm_providers) + if declared_client_types != registered_client_types: + logger.error( + f"插件 {meta.plugin_id} LLM Provider 声明不一致: " + f"manifest={declared_client_types}, code={registered_client_types}" + ) + return False reg_payload = RegisterPluginPayload( plugin_id=meta.plugin_id, plugin_version=meta.version, components=components, + llm_providers=llm_providers, capabilities_required=meta.capabilities_required, dependencies=meta.dependencies, config_reload_subscriptions=config_reload_subscriptions, @@ -1629,6 +1665,50 @@ class PluginRunner: resp_payload = InvokeResultPayload(success=False, result=str(e)) return envelope.make_response(payload=resp_payload.model_dump()) + async def _handle_llm_provider_invoke(self, envelope: Envelope) -> Envelope: + """处理 LLM Provider 调用请求。 + + Args: + envelope: RPC 请求信封。 + + Returns: + Envelope: 标准化后的 Provider 调用结果。 + """ + try: + invoke = LLMProviderInvokePayload.model_validate(envelope.payload) + except Exception as exc: + return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc)) + + plugin_id = envelope.plugin_id + meta = self._loader.get_plugin(plugin_id) + if meta is None: + return envelope.make_error_response( + ErrorCode.E_PLUGIN_NOT_FOUND.value, + f"插件 {plugin_id} 未加载", + ) + + handler_name = meta.llm_provider_handlers.get(invoke.client_type, "") + handler_method = getattr(meta.instance, handler_name, None) if handler_name else None + if handler_method is None or not callable(handler_method): + return envelope.make_error_response( + ErrorCode.E_METHOD_NOT_ALLOWED.value, + f"插件 {plugin_id} 未注册 LLM Provider: {invoke.client_type}", + ) + + try: + result = handler_method(operation=invoke.operation, request=invoke.request) + if inspect.isawaitable(result): + result = await result + resp_payload = InvokeResultPayload(success=True, result=result) + return envelope.make_response(payload=resp_payload.model_dump()) + except Exception as exc: + logger.error( + f"插件 {plugin_id} LLM Provider {invoke.client_type} 执行异常: {exc}", + exc_info=True, + ) + resp_payload = InvokeResultPayload(success=False, result=str(exc)) + return envelope.make_response(payload=resp_payload.model_dump()) + async def _handle_event_invoke(self, envelope: Envelope) -> Envelope: """处理 EventHandler 调用请求