feat: Add LLM Provider support in plugin runtime
- Introduced LLM Provider declarations in plugin manifests, allowing plugins to specify their LLM capabilities. - Implemented validation for LLM Provider declarations to prevent duplicates and conflicts. - Enhanced the PluginRunner to handle LLM Provider invocation requests, enabling plugins to interact with LLM Providers seamlessly. - Added a ClientRegistry to manage LLM Provider registrations and ensure no conflicts arise between different plugins. - Created a PluginLLMClient to facilitate communication with LLM Providers through the plugin runtime. - Developed tests to ensure proper registration and conflict handling of LLM Providers.
This commit is contained in:
101
pytests/test_llm_provider_registry.py
Normal file
101
pytests/test_llm_provider_registry.py
Normal file
@@ -0,0 +1,101 @@
|
||||
from typing import List
|
||||
|
||||
from src.llm_models.model_client.base_client import (
|
||||
APIResponse,
|
||||
AudioTranscriptionRequest,
|
||||
BaseClient,
|
||||
ClientProviderRegistration,
|
||||
ClientRegistry,
|
||||
EmbeddingRequest,
|
||||
ResponseRequest,
|
||||
)
|
||||
|
||||
|
||||
class DummyClient(BaseClient):
|
||||
"""测试用 LLM 客户端。"""
|
||||
|
||||
async def get_response(self, request: ResponseRequest) -> APIResponse:
|
||||
"""获取测试响应。
|
||||
|
||||
Args:
|
||||
request: 统一响应请求。
|
||||
|
||||
Returns:
|
||||
APIResponse: 测试响应。
|
||||
"""
|
||||
del request
|
||||
return APIResponse(content="ok")
|
||||
|
||||
async def get_embedding(self, request: EmbeddingRequest) -> APIResponse:
|
||||
"""获取测试嵌入。
|
||||
|
||||
Args:
|
||||
request: 统一嵌入请求。
|
||||
|
||||
Returns:
|
||||
APIResponse: 测试嵌入响应。
|
||||
"""
|
||||
del request
|
||||
return APIResponse(embedding=[1.0])
|
||||
|
||||
async def get_audio_transcriptions(self, request: AudioTranscriptionRequest) -> APIResponse:
|
||||
"""获取测试音频转写。
|
||||
|
||||
Args:
|
||||
request: 统一音频转写请求。
|
||||
|
||||
Returns:
|
||||
APIResponse: 测试音频转写响应。
|
||||
"""
|
||||
del request
|
||||
return APIResponse(content="audio")
|
||||
|
||||
def get_support_image_formats(self) -> List[str]:
|
||||
"""获取测试支持的图片格式。
|
||||
|
||||
Returns:
|
||||
List[str]: 支持的图片格式列表。
|
||||
"""
|
||||
return ["png"]
|
||||
|
||||
|
||||
def test_client_registry_rejects_provider_conflict():
|
||||
"""同一 client_type 被不同插件注册时应拒绝。"""
|
||||
registry = ClientRegistry()
|
||||
registry.replace_plugin_providers(
|
||||
"plugin.alpha",
|
||||
[
|
||||
ClientProviderRegistration(
|
||||
client_type="example",
|
||||
factory=DummyClient,
|
||||
owner_plugin_id="plugin.alpha",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
try:
|
||||
registry.validate_plugin_provider_replacement("plugin.beta", ["example"])
|
||||
except ValueError as exc:
|
||||
assert "冲突" in str(exc)
|
||||
else:
|
||||
raise AssertionError("不同插件注册相同 client_type 应失败")
|
||||
|
||||
|
||||
def test_client_registry_unregisters_plugin_providers():
|
||||
"""插件注销时应移除它拥有的 Provider 注册。"""
|
||||
registry = ClientRegistry()
|
||||
registry.replace_plugin_providers(
|
||||
"plugin.alpha",
|
||||
[
|
||||
ClientProviderRegistration(
|
||||
client_type="example",
|
||||
factory=DummyClient,
|
||||
owner_plugin_id="plugin.alpha",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
removed_count = registry.unregister_plugin_providers("plugin.alpha")
|
||||
|
||||
assert removed_count == 1
|
||||
assert "example" not in registry.client_registry
|
||||
@@ -30,6 +30,7 @@ def build_test_manifest(
|
||||
name: str = "测试插件",
|
||||
description: str = "测试插件描述",
|
||||
dependencies: list[dict[str, str]] | None = None,
|
||||
llm_providers: list[dict[str, str]] | None = None,
|
||||
capabilities: list[str] | None = None,
|
||||
host_min_version: str = "0.12.0",
|
||||
host_max_version: str = "1.0.0",
|
||||
@@ -44,6 +45,7 @@ def build_test_manifest(
|
||||
name: 展示名称。
|
||||
description: 插件描述。
|
||||
dependencies: 依赖声明列表。
|
||||
llm_providers: LLM Provider 静态声明列表。
|
||||
capabilities: 能力声明列表。
|
||||
host_min_version: Host 最低支持版本。
|
||||
host_max_version: Host 最高支持版本。
|
||||
@@ -75,6 +77,7 @@ def build_test_manifest(
|
||||
"max_version": sdk_max_version,
|
||||
},
|
||||
"dependencies": dependencies or [],
|
||||
"llm_providers": llm_providers or [],
|
||||
"capabilities": capabilities or [],
|
||||
"i18n": {
|
||||
"default_locale": "zh-CN",
|
||||
@@ -89,6 +92,7 @@ def build_test_manifest_model(
|
||||
*,
|
||||
version: str = "1.0.0",
|
||||
dependencies: list[dict[str, str]] | None = None,
|
||||
llm_providers: list[dict[str, str]] | None = None,
|
||||
capabilities: list[str] | None = None,
|
||||
host_version: str = "1.0.0",
|
||||
sdk_version: str = "2.0.1",
|
||||
@@ -99,6 +103,7 @@ def build_test_manifest_model(
|
||||
plugin_id: 插件 ID。
|
||||
version: 插件版本。
|
||||
dependencies: 依赖声明列表。
|
||||
llm_providers: LLM Provider 静态声明列表。
|
||||
capabilities: 能力声明列表。
|
||||
host_version: 当前测试使用的 Host 版本。
|
||||
sdk_version: 当前测试使用的 SDK 版本。
|
||||
@@ -114,6 +119,7 @@ def build_test_manifest_model(
|
||||
plugin_id,
|
||||
version=version,
|
||||
dependencies=dependencies,
|
||||
llm_providers=llm_providers,
|
||||
capabilities=capabilities,
|
||||
)
|
||||
)
|
||||
@@ -1055,6 +1061,63 @@ class TestManifestValidator:
|
||||
assert validator.validate(manifest) is False
|
||||
assert any("Python 包依赖冲突" in error for error in validator.errors)
|
||||
|
||||
def test_llm_provider_manifest_declaration(self):
|
||||
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
|
||||
|
||||
validator = ManifestValidator(host_version="1.0.0", sdk_version="2.0.1")
|
||||
manifest = build_test_manifest(
|
||||
"test.llm-provider",
|
||||
llm_providers=[
|
||||
{
|
||||
"client_type": "example.provider",
|
||||
"name": "Example Provider",
|
||||
"description": "测试 Provider",
|
||||
"version": "1.0.0",
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
parsed_manifest = validator.parse_manifest(manifest)
|
||||
|
||||
assert parsed_manifest is not None
|
||||
assert parsed_manifest.llm_provider_client_types == ["example.provider"]
|
||||
|
||||
def test_duplicate_llm_provider_manifest_declaration_is_rejected(self):
|
||||
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
|
||||
|
||||
validator = ManifestValidator(host_version="1.0.0", sdk_version="2.0.1")
|
||||
manifest = build_test_manifest(
|
||||
"test.llm-provider-duplicate",
|
||||
llm_providers=[
|
||||
{"client_type": "example.provider"},
|
||||
{"client_type": "example.provider"},
|
||||
],
|
||||
)
|
||||
|
||||
assert validator.validate(manifest) is False
|
||||
assert any("重复的 LLM Provider" in error for error in validator.errors)
|
||||
|
||||
|
||||
def test_llm_provider_conflict_blocks_all_conflicting_plugins(tmp_path: Path):
|
||||
from src.plugin_runtime.integration import PluginRuntimeManager
|
||||
|
||||
plugin_root = tmp_path / "plugins"
|
||||
plugin_root.mkdir()
|
||||
for plugin_id in ["test.provider-alpha", "test.provider-beta"]:
|
||||
plugin_dir = plugin_root / plugin_id
|
||||
plugin_dir.mkdir()
|
||||
manifest = build_test_manifest(
|
||||
plugin_id,
|
||||
llm_providers=[{"client_type": "example.provider"}],
|
||||
)
|
||||
(plugin_dir / "_manifest.json").write_text(json.dumps(manifest), encoding="utf-8")
|
||||
(plugin_dir / "plugin.py").write_text("def create_plugin():\n return None\n", encoding="utf-8")
|
||||
|
||||
blocked_reasons = PluginRuntimeManager._discover_llm_provider_conflicts([plugin_root])
|
||||
|
||||
assert set(blocked_reasons) == {"test.provider-alpha", "test.provider-beta"}
|
||||
assert all("example.provider" in reason for reason in blocked_reasons.values())
|
||||
|
||||
|
||||
class TestVersionComparator:
|
||||
"""版本号比较器测试"""
|
||||
|
||||
Reference in New Issue
Block a user