Merge branch 'Mai-with-u:dev' into dev
This commit is contained in:
@@ -16,7 +16,7 @@ async def test_handle_file_changes_throttles_reload():
|
|||||||
|
|
||||||
called = 0
|
called = 0
|
||||||
|
|
||||||
async def reload_stub() -> bool:
|
async def reload_stub(changed_scopes=None) -> bool:
|
||||||
nonlocal called
|
nonlocal called
|
||||||
called += 1
|
called += 1
|
||||||
return True
|
return True
|
||||||
@@ -36,7 +36,7 @@ async def test_handle_file_changes_timeout_logged(caplog):
|
|||||||
manager._hot_reload_min_interval_s = 0.0
|
manager._hot_reload_min_interval_s = 0.0
|
||||||
manager._hot_reload_timeout_s = 0.01
|
manager._hot_reload_timeout_s = 0.01
|
||||||
|
|
||||||
async def reload_stub() -> bool:
|
async def reload_stub(changed_scopes=None) -> bool:
|
||||||
await asyncio.sleep(0.05)
|
await asyncio.sleep(0.05)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -55,7 +55,7 @@ async def test_handle_file_changes_empty_skips_reload():
|
|||||||
|
|
||||||
called = 0
|
called = 0
|
||||||
|
|
||||||
async def reload_stub() -> bool:
|
async def reload_stub(changed_scopes=None) -> bool:
|
||||||
nonlocal called
|
nonlocal called
|
||||||
called += 1
|
called += 1
|
||||||
return True
|
return True
|
||||||
|
|||||||
33
pytests/config_test/test_config_manager_startup_upgrade.py
Normal file
33
pytests/config_test/test_config_manager_startup_upgrade.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.config import config as config_module
|
||||||
|
from src.config.config import Config, ConfigManager, ModelConfig
|
||||||
|
|
||||||
|
|
||||||
|
class _StartupUpgradeExit(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_initialize_upgrades_bot_and_model_config_before_exit(monkeypatch):
|
||||||
|
manager = ConfigManager()
|
||||||
|
loaded_config_classes: list[type[Any]] = []
|
||||||
|
exit_codes: list[int | None] = []
|
||||||
|
|
||||||
|
def fake_load_config_from_file(config_class, config_path, new_ver, override_repr=False):
|
||||||
|
loaded_config_classes.append(config_class)
|
||||||
|
return object(), True
|
||||||
|
|
||||||
|
def fake_exit(code: int | None = None):
|
||||||
|
exit_codes.append(code)
|
||||||
|
raise _StartupUpgradeExit
|
||||||
|
|
||||||
|
monkeypatch.setattr(config_module, "load_config_from_file", fake_load_config_from_file)
|
||||||
|
monkeypatch.setattr(config_module.sys, "exit", fake_exit)
|
||||||
|
|
||||||
|
with pytest.raises(_StartupUpgradeExit):
|
||||||
|
manager.initialize()
|
||||||
|
|
||||||
|
assert loaded_config_classes == [Config, ModelConfig]
|
||||||
|
assert exit_codes == [0]
|
||||||
101
pytests/test_llm_provider_registry.py
Normal file
101
pytests/test_llm_provider_registry.py
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
|
from src.llm_models.model_client.base_client import (
|
||||||
|
APIResponse,
|
||||||
|
AudioTranscriptionRequest,
|
||||||
|
BaseClient,
|
||||||
|
ClientProviderRegistration,
|
||||||
|
ClientRegistry,
|
||||||
|
EmbeddingRequest,
|
||||||
|
ResponseRequest,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DummyClient(BaseClient):
|
||||||
|
"""测试用 LLM 客户端。"""
|
||||||
|
|
||||||
|
async def get_response(self, request: ResponseRequest) -> APIResponse:
|
||||||
|
"""获取测试响应。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: 统一响应请求。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
APIResponse: 测试响应。
|
||||||
|
"""
|
||||||
|
del request
|
||||||
|
return APIResponse(content="ok")
|
||||||
|
|
||||||
|
async def get_embedding(self, request: EmbeddingRequest) -> APIResponse:
|
||||||
|
"""获取测试嵌入。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: 统一嵌入请求。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
APIResponse: 测试嵌入响应。
|
||||||
|
"""
|
||||||
|
del request
|
||||||
|
return APIResponse(embedding=[1.0])
|
||||||
|
|
||||||
|
async def get_audio_transcriptions(self, request: AudioTranscriptionRequest) -> APIResponse:
|
||||||
|
"""获取测试音频转写。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: 统一音频转写请求。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
APIResponse: 测试音频转写响应。
|
||||||
|
"""
|
||||||
|
del request
|
||||||
|
return APIResponse(content="audio")
|
||||||
|
|
||||||
|
def get_support_image_formats(self) -> List[str]:
|
||||||
|
"""获取测试支持的图片格式。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[str]: 支持的图片格式列表。
|
||||||
|
"""
|
||||||
|
return ["png"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_client_registry_rejects_provider_conflict():
|
||||||
|
"""同一 client_type 被不同插件注册时应拒绝。"""
|
||||||
|
registry = ClientRegistry()
|
||||||
|
registry.replace_plugin_providers(
|
||||||
|
"plugin.alpha",
|
||||||
|
[
|
||||||
|
ClientProviderRegistration(
|
||||||
|
client_type="example",
|
||||||
|
factory=DummyClient,
|
||||||
|
owner_plugin_id="plugin.alpha",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
registry.validate_plugin_provider_replacement("plugin.beta", ["example"])
|
||||||
|
except ValueError as exc:
|
||||||
|
assert "冲突" in str(exc)
|
||||||
|
else:
|
||||||
|
raise AssertionError("不同插件注册相同 client_type 应失败")
|
||||||
|
|
||||||
|
|
||||||
|
def test_client_registry_unregisters_plugin_providers():
|
||||||
|
"""插件注销时应移除它拥有的 Provider 注册。"""
|
||||||
|
registry = ClientRegistry()
|
||||||
|
registry.replace_plugin_providers(
|
||||||
|
"plugin.alpha",
|
||||||
|
[
|
||||||
|
ClientProviderRegistration(
|
||||||
|
client_type="example",
|
||||||
|
factory=DummyClient,
|
||||||
|
owner_plugin_id="plugin.alpha",
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
removed_count = registry.unregister_plugin_providers("plugin.alpha")
|
||||||
|
|
||||||
|
assert removed_count == 1
|
||||||
|
assert "example" not in registry.client_registry
|
||||||
@@ -30,6 +30,7 @@ def build_test_manifest(
|
|||||||
name: str = "测试插件",
|
name: str = "测试插件",
|
||||||
description: str = "测试插件描述",
|
description: str = "测试插件描述",
|
||||||
dependencies: list[dict[str, str]] | None = None,
|
dependencies: list[dict[str, str]] | None = None,
|
||||||
|
llm_providers: list[dict[str, str]] | None = None,
|
||||||
capabilities: list[str] | None = None,
|
capabilities: list[str] | None = None,
|
||||||
host_min_version: str = "0.12.0",
|
host_min_version: str = "0.12.0",
|
||||||
host_max_version: str = "1.0.0",
|
host_max_version: str = "1.0.0",
|
||||||
@@ -44,6 +45,7 @@ def build_test_manifest(
|
|||||||
name: 展示名称。
|
name: 展示名称。
|
||||||
description: 插件描述。
|
description: 插件描述。
|
||||||
dependencies: 依赖声明列表。
|
dependencies: 依赖声明列表。
|
||||||
|
llm_providers: LLM Provider 静态声明列表。
|
||||||
capabilities: 能力声明列表。
|
capabilities: 能力声明列表。
|
||||||
host_min_version: Host 最低支持版本。
|
host_min_version: Host 最低支持版本。
|
||||||
host_max_version: Host 最高支持版本。
|
host_max_version: Host 最高支持版本。
|
||||||
@@ -75,6 +77,7 @@ def build_test_manifest(
|
|||||||
"max_version": sdk_max_version,
|
"max_version": sdk_max_version,
|
||||||
},
|
},
|
||||||
"dependencies": dependencies or [],
|
"dependencies": dependencies or [],
|
||||||
|
"llm_providers": llm_providers or [],
|
||||||
"capabilities": capabilities or [],
|
"capabilities": capabilities or [],
|
||||||
"i18n": {
|
"i18n": {
|
||||||
"default_locale": "zh-CN",
|
"default_locale": "zh-CN",
|
||||||
@@ -89,6 +92,7 @@ def build_test_manifest_model(
|
|||||||
*,
|
*,
|
||||||
version: str = "1.0.0",
|
version: str = "1.0.0",
|
||||||
dependencies: list[dict[str, str]] | None = None,
|
dependencies: list[dict[str, str]] | None = None,
|
||||||
|
llm_providers: list[dict[str, str]] | None = None,
|
||||||
capabilities: list[str] | None = None,
|
capabilities: list[str] | None = None,
|
||||||
host_version: str = "1.0.0",
|
host_version: str = "1.0.0",
|
||||||
sdk_version: str = "2.0.1",
|
sdk_version: str = "2.0.1",
|
||||||
@@ -99,6 +103,7 @@ def build_test_manifest_model(
|
|||||||
plugin_id: 插件 ID。
|
plugin_id: 插件 ID。
|
||||||
version: 插件版本。
|
version: 插件版本。
|
||||||
dependencies: 依赖声明列表。
|
dependencies: 依赖声明列表。
|
||||||
|
llm_providers: LLM Provider 静态声明列表。
|
||||||
capabilities: 能力声明列表。
|
capabilities: 能力声明列表。
|
||||||
host_version: 当前测试使用的 Host 版本。
|
host_version: 当前测试使用的 Host 版本。
|
||||||
sdk_version: 当前测试使用的 SDK 版本。
|
sdk_version: 当前测试使用的 SDK 版本。
|
||||||
@@ -114,6 +119,7 @@ def build_test_manifest_model(
|
|||||||
plugin_id,
|
plugin_id,
|
||||||
version=version,
|
version=version,
|
||||||
dependencies=dependencies,
|
dependencies=dependencies,
|
||||||
|
llm_providers=llm_providers,
|
||||||
capabilities=capabilities,
|
capabilities=capabilities,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -1055,6 +1061,63 @@ class TestManifestValidator:
|
|||||||
assert validator.validate(manifest) is False
|
assert validator.validate(manifest) is False
|
||||||
assert any("Python 包依赖冲突" in error for error in validator.errors)
|
assert any("Python 包依赖冲突" in error for error in validator.errors)
|
||||||
|
|
||||||
|
def test_llm_provider_manifest_declaration(self):
|
||||||
|
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
|
||||||
|
|
||||||
|
validator = ManifestValidator(host_version="1.0.0", sdk_version="2.0.1")
|
||||||
|
manifest = build_test_manifest(
|
||||||
|
"test.llm-provider",
|
||||||
|
llm_providers=[
|
||||||
|
{
|
||||||
|
"client_type": "example.provider",
|
||||||
|
"name": "Example Provider",
|
||||||
|
"description": "测试 Provider",
|
||||||
|
"version": "1.0.0",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
parsed_manifest = validator.parse_manifest(manifest)
|
||||||
|
|
||||||
|
assert parsed_manifest is not None
|
||||||
|
assert parsed_manifest.llm_provider_client_types == ["example.provider"]
|
||||||
|
|
||||||
|
def test_duplicate_llm_provider_manifest_declaration_is_rejected(self):
|
||||||
|
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
|
||||||
|
|
||||||
|
validator = ManifestValidator(host_version="1.0.0", sdk_version="2.0.1")
|
||||||
|
manifest = build_test_manifest(
|
||||||
|
"test.llm-provider-duplicate",
|
||||||
|
llm_providers=[
|
||||||
|
{"client_type": "example.provider"},
|
||||||
|
{"client_type": "example.provider"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert validator.validate(manifest) is False
|
||||||
|
assert any("重复的 LLM Provider" in error for error in validator.errors)
|
||||||
|
|
||||||
|
|
||||||
|
def test_llm_provider_conflict_blocks_all_conflicting_plugins(tmp_path: Path):
|
||||||
|
from src.plugin_runtime.integration import PluginRuntimeManager
|
||||||
|
|
||||||
|
plugin_root = tmp_path / "plugins"
|
||||||
|
plugin_root.mkdir()
|
||||||
|
for plugin_id in ["test.provider-alpha", "test.provider-beta"]:
|
||||||
|
plugin_dir = plugin_root / plugin_id
|
||||||
|
plugin_dir.mkdir()
|
||||||
|
manifest = build_test_manifest(
|
||||||
|
plugin_id,
|
||||||
|
llm_providers=[{"client_type": "example.provider"}],
|
||||||
|
)
|
||||||
|
(plugin_dir / "_manifest.json").write_text(json.dumps(manifest), encoding="utf-8")
|
||||||
|
(plugin_dir / "plugin.py").write_text("def create_plugin():\n return None\n", encoding="utf-8")
|
||||||
|
|
||||||
|
blocked_reasons = PluginRuntimeManager._discover_llm_provider_conflicts([plugin_root])
|
||||||
|
|
||||||
|
assert set(blocked_reasons) == {"test.provider-alpha", "test.provider-beta"}
|
||||||
|
assert all("example.provider" in reason for reason in blocked_reasons.values())
|
||||||
|
|
||||||
|
|
||||||
class TestVersionComparator:
|
class TestVersionComparator:
|
||||||
"""版本号比较器测试"""
|
"""版本号比较器测试"""
|
||||||
|
|||||||
@@ -196,8 +196,15 @@ class ConfigManager:
|
|||||||
def initialize(self):
|
def initialize(self):
|
||||||
logger.info(t("config.current_version", version=MMC_VERSION))
|
logger.info(t("config.current_version", version=MMC_VERSION))
|
||||||
logger.info(t("config.loading"))
|
logger.info(t("config.loading"))
|
||||||
self.global_config = self.load_global_config()
|
self.global_config, global_updated = load_config_from_file(Config, self.bot_config_path, CONFIG_VERSION)
|
||||||
self.model_config = self.load_model_config()
|
self.model_config, model_updated = load_config_from_file(
|
||||||
|
ModelConfig,
|
||||||
|
self.model_config_path,
|
||||||
|
MODEL_CONFIG_VERSION,
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
if global_updated or model_updated:
|
||||||
|
sys.exit(0) # 配置已自动升级,退出一次让用户确认新配置后再启动
|
||||||
logger.info(t("config.loaded"))
|
logger.info(t("config.loaded"))
|
||||||
|
|
||||||
def load_global_config(self) -> Config:
|
def load_global_config(self) -> Config:
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Callable, Coroutine, Dict, List, Tuple, Type
|
from typing import Any, Callable, Coroutine, Dict, List, Set, Tuple, Type
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
@@ -204,19 +204,48 @@ class BaseClient(ABC):
|
|||||||
raise NotImplementedError("'get_support_image_formats' method should be overridden in subclasses")
|
raise NotImplementedError("'get_support_image_formats' method should be overridden in subclasses")
|
||||||
|
|
||||||
|
|
||||||
|
ClientFactory = Callable[[APIProvider], BaseClient]
|
||||||
|
"""根据 APIProvider 创建客户端实例的工厂函数。"""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class ClientProviderRegistration:
|
||||||
|
"""LLM Provider 客户端类型注册信息。"""
|
||||||
|
|
||||||
|
client_type: str
|
||||||
|
"""客户端类型标识,对应模型配置中的 `api_providers[].client_type`。"""
|
||||||
|
|
||||||
|
factory: ClientFactory
|
||||||
|
"""客户端实例工厂。"""
|
||||||
|
|
||||||
|
owner_plugin_id: str | None = None
|
||||||
|
"""拥有该客户端类型的插件 ID;主程序内置类型为 ``None``。"""
|
||||||
|
|
||||||
|
version: str = "1.0.0"
|
||||||
|
"""Provider 实现版本。"""
|
||||||
|
|
||||||
|
description: str = ""
|
||||||
|
"""Provider 描述文本。"""
|
||||||
|
|
||||||
|
builtin: bool = False
|
||||||
|
"""是否为主程序内置 Provider。"""
|
||||||
|
|
||||||
|
|
||||||
class ClientRegistry:
|
class ClientRegistry:
|
||||||
"""客户端注册表。"""
|
"""客户端注册表。"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
"""初始化注册表并绑定配置重载回调。"""
|
"""初始化注册表并绑定配置重载回调。"""
|
||||||
self.client_registry: Dict[str, Type[BaseClient]] = {}
|
self.client_registry: Dict[str, ClientProviderRegistration] = {}
|
||||||
"""APIProvider.type -> BaseClient的映射表"""
|
"""APIProvider.client_type -> Provider 注册信息映射表。"""
|
||||||
self.client_instance_cache: Dict[str, BaseClient] = {}
|
self.client_instance_cache: Dict[str, BaseClient] = {}
|
||||||
"""APIProvider.name -> BaseClient的映射表"""
|
"""APIProvider.name -> BaseClient的映射表"""
|
||||||
|
self._owner_client_types: Dict[str, Set[str]] = {}
|
||||||
|
"""插件 ID -> 该插件拥有的 client_type 集合。"""
|
||||||
config_manager.register_reload_callback(self.clear_client_instance_cache)
|
config_manager.register_reload_callback(self.clear_client_instance_cache)
|
||||||
|
|
||||||
def register_client_class(self, client_type: str) -> Callable[[Type[BaseClient]], Type[BaseClient]]:
|
def register_client_class(self, client_type: str) -> Callable[[Type[BaseClient]], Type[BaseClient]]:
|
||||||
"""注册 API 客户端类。
|
"""注册主程序内置 API 客户端类。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
client_type: 客户端类型标识。
|
client_type: 客户端类型标识。
|
||||||
@@ -226,13 +255,180 @@ class ClientRegistry:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(cls: Type[BaseClient]) -> Type[BaseClient]:
|
def decorator(cls: Type[BaseClient]) -> Type[BaseClient]:
|
||||||
|
"""将内置客户端类注册到全局客户端注册表。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cls: 待注册的客户端类。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Type[BaseClient]: 原始客户端类。
|
||||||
|
"""
|
||||||
if not issubclass(cls, BaseClient):
|
if not issubclass(cls, BaseClient):
|
||||||
raise TypeError(f"{cls.__name__} is not a subclass of BaseClient")
|
raise TypeError(f"{cls.__name__} is not a subclass of BaseClient")
|
||||||
self.client_registry[client_type] = cls
|
self.register_provider(
|
||||||
|
ClientProviderRegistration(
|
||||||
|
client_type=client_type,
|
||||||
|
factory=cls,
|
||||||
|
builtin=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _normalize_client_type(client_type: str) -> str:
|
||||||
|
"""规范化客户端类型标识。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client_type: 原始客户端类型标识。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 去除首尾空白后的客户端类型标识。
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 当客户端类型为空时抛出。
|
||||||
|
"""
|
||||||
|
normalized_client_type = str(client_type or "").strip()
|
||||||
|
if not normalized_client_type:
|
||||||
|
raise ValueError("client_type 不能为空")
|
||||||
|
return normalized_client_type
|
||||||
|
|
||||||
|
def register_provider(self, registration: ClientProviderRegistration) -> None:
|
||||||
|
"""注册单个客户端类型。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
registration: Provider 注册信息。
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 当客户端类型冲突时抛出。
|
||||||
|
"""
|
||||||
|
client_type = self._normalize_client_type(registration.client_type)
|
||||||
|
existing = self.client_registry.get(client_type)
|
||||||
|
if existing is not None and existing.owner_plugin_id != registration.owner_plugin_id:
|
||||||
|
raise ValueError(
|
||||||
|
f"LLM Provider client_type 冲突: {client_type} 已由 {existing.owner_plugin_id or 'host'} 注册"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.client_registry[client_type] = ClientProviderRegistration(
|
||||||
|
client_type=client_type,
|
||||||
|
factory=registration.factory,
|
||||||
|
owner_plugin_id=registration.owner_plugin_id,
|
||||||
|
version=registration.version,
|
||||||
|
description=registration.description,
|
||||||
|
builtin=registration.builtin,
|
||||||
|
)
|
||||||
|
if registration.owner_plugin_id:
|
||||||
|
self._owner_client_types.setdefault(registration.owner_plugin_id, set()).add(client_type)
|
||||||
|
self.clear_client_instance_cache_by_client_type(client_type)
|
||||||
|
|
||||||
|
def validate_plugin_provider_replacement(self, plugin_id: str, client_types: List[str]) -> None:
|
||||||
|
"""校验插件 Provider 替换是否会造成运行时冲突。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_id: 目标插件 ID。
|
||||||
|
client_types: 插件即将注册的客户端类型列表。
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 当客户端类型为空、重复或与其他 owner 冲突时抛出。
|
||||||
|
"""
|
||||||
|
normalized_plugin_id = str(plugin_id or "").strip()
|
||||||
|
if not normalized_plugin_id:
|
||||||
|
raise ValueError("plugin_id 不能为空")
|
||||||
|
|
||||||
|
normalized_client_types = [self._normalize_client_type(client_type) for client_type in client_types]
|
||||||
|
duplicate_client_types = sorted(
|
||||||
|
{
|
||||||
|
client_type
|
||||||
|
for client_type in normalized_client_types
|
||||||
|
if normalized_client_types.count(client_type) > 1
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if duplicate_client_types:
|
||||||
|
raise ValueError(f"插件 {normalized_plugin_id} 重复声明 LLM Provider: {', '.join(duplicate_client_types)}")
|
||||||
|
|
||||||
|
for client_type in normalized_client_types:
|
||||||
|
existing = self.client_registry.get(client_type)
|
||||||
|
if existing is None or existing.owner_plugin_id == normalized_plugin_id:
|
||||||
|
continue
|
||||||
|
raise ValueError(
|
||||||
|
f"LLM Provider client_type 冲突: {client_type} 已由 {existing.owner_plugin_id or 'host'} 注册"
|
||||||
|
)
|
||||||
|
|
||||||
|
def replace_plugin_providers(
|
||||||
|
self,
|
||||||
|
plugin_id: str,
|
||||||
|
registrations: List[ClientProviderRegistration],
|
||||||
|
) -> None:
|
||||||
|
"""原子替换一个插件拥有的全部 Provider 注册。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_id: 目标插件 ID。
|
||||||
|
registrations: 插件当前上报的 Provider 注册列表。
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 当注册信息不合法或存在冲突时抛出。
|
||||||
|
"""
|
||||||
|
normalized_plugin_id = str(plugin_id or "").strip()
|
||||||
|
self.validate_plugin_provider_replacement(
|
||||||
|
normalized_plugin_id,
|
||||||
|
[registration.client_type for registration in registrations],
|
||||||
|
)
|
||||||
|
self.unregister_plugin_providers(normalized_plugin_id)
|
||||||
|
for registration in registrations:
|
||||||
|
self.register_provider(
|
||||||
|
ClientProviderRegistration(
|
||||||
|
client_type=registration.client_type,
|
||||||
|
factory=registration.factory,
|
||||||
|
owner_plugin_id=normalized_plugin_id,
|
||||||
|
version=registration.version,
|
||||||
|
description=registration.description,
|
||||||
|
builtin=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def unregister_plugin_providers(self, plugin_id: str) -> int:
|
||||||
|
"""注销一个插件拥有的全部 Provider 注册。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_id: 目标插件 ID。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: 被注销的客户端类型数量。
|
||||||
|
"""
|
||||||
|
normalized_plugin_id = str(plugin_id or "").strip()
|
||||||
|
if not normalized_plugin_id:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
client_types = self._owner_client_types.pop(normalized_plugin_id, set())
|
||||||
|
removed_count = 0
|
||||||
|
for client_type in client_types:
|
||||||
|
registration = self.client_registry.get(client_type)
|
||||||
|
if registration is None or registration.owner_plugin_id != normalized_plugin_id:
|
||||||
|
continue
|
||||||
|
self.client_registry.pop(client_type, None)
|
||||||
|
self.clear_client_instance_cache_by_client_type(client_type)
|
||||||
|
removed_count += 1
|
||||||
|
return removed_count
|
||||||
|
|
||||||
|
def clear_client_instance_cache_by_client_type(self, client_type: str) -> None:
|
||||||
|
"""清理指定客户端类型对应的客户端实例缓存。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client_type: 需要清理缓存的客户端类型。
|
||||||
|
"""
|
||||||
|
normalized_client_type = str(client_type or "").strip()
|
||||||
|
if not normalized_client_type:
|
||||||
|
return
|
||||||
|
|
||||||
|
stale_provider_names = [
|
||||||
|
provider_name
|
||||||
|
for provider_name, client in self.client_instance_cache.items()
|
||||||
|
if client.api_provider.client_type == normalized_client_type
|
||||||
|
]
|
||||||
|
for provider_name in stale_provider_names:
|
||||||
|
self.client_instance_cache.pop(provider_name, None)
|
||||||
|
|
||||||
def get_client_class_instance(self, api_provider: APIProvider, force_new: bool = False) -> BaseClient:
|
def get_client_class_instance(self, api_provider: APIProvider, force_new: bool = False) -> BaseClient:
|
||||||
"""获取注册的 API 客户端实例。
|
"""获取注册的 API 客户端实例。
|
||||||
|
|
||||||
@@ -249,15 +445,14 @@ class ClientRegistry:
|
|||||||
|
|
||||||
# 如果强制创建新实例,直接创建不使用缓存
|
# 如果强制创建新实例,直接创建不使用缓存
|
||||||
if force_new:
|
if force_new:
|
||||||
if client_class := self.client_registry.get(api_provider.client_type):
|
if registration := self.client_registry.get(api_provider.client_type):
|
||||||
return client_class(api_provider)
|
return registration.factory(api_provider)
|
||||||
else:
|
raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
|
||||||
raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
|
|
||||||
|
|
||||||
# 正常的缓存逻辑
|
# 正常的缓存逻辑
|
||||||
if api_provider.name not in self.client_instance_cache:
|
if api_provider.name not in self.client_instance_cache:
|
||||||
if client_class := self.client_registry.get(api_provider.client_type):
|
if registration := self.client_registry.get(api_provider.client_type):
|
||||||
self.client_instance_cache[api_provider.name] = client_class(api_provider)
|
self.client_instance_cache[api_provider.name] = registration.factory(api_provider)
|
||||||
else:
|
else:
|
||||||
raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
|
raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
|
||||||
return self.client_instance_cache[api_provider.name]
|
return self.client_instance_cache[api_provider.name]
|
||||||
|
|||||||
191
src/llm_models/model_client/plugin_client.py
Normal file
191
src/llm_models/model_client/plugin_client.py
Normal file
@@ -0,0 +1,191 @@
|
|||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from src.config.model_configs import APIProvider
|
||||||
|
from src.llm_models.exceptions import RespParseException
|
||||||
|
from src.llm_models.model_client.base_client import (
|
||||||
|
APIResponse,
|
||||||
|
AudioTranscriptionRequest,
|
||||||
|
BaseClient,
|
||||||
|
EmbeddingRequest,
|
||||||
|
ResponseRequest,
|
||||||
|
UsageRecord,
|
||||||
|
)
|
||||||
|
from src.llm_models.request_snapshot import (
|
||||||
|
deserialize_tool_calls_snapshot,
|
||||||
|
serialize_audio_request_snapshot,
|
||||||
|
serialize_embedding_request_snapshot,
|
||||||
|
serialize_response_request_snapshot,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginLLMClient(BaseClient):
|
||||||
|
"""通过插件 Runner RPC 调用第三方 LLM Provider 的客户端代理。"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_provider: APIProvider,
|
||||||
|
supervisor: Any,
|
||||||
|
plugin_id: str,
|
||||||
|
client_type: str,
|
||||||
|
) -> None:
|
||||||
|
"""初始化插件 LLM Provider 客户端代理。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
api_provider: 当前 API Provider 配置。
|
||||||
|
supervisor: 拥有目标插件的 Supervisor。
|
||||||
|
plugin_id: 目标插件 ID。
|
||||||
|
client_type: 目标客户端类型。
|
||||||
|
"""
|
||||||
|
super().__init__(api_provider)
|
||||||
|
self._supervisor = supervisor
|
||||||
|
self._plugin_id = plugin_id
|
||||||
|
self._client_type = client_type
|
||||||
|
|
||||||
|
async def get_response(self, request: ResponseRequest) -> APIResponse:
|
||||||
|
"""获取对话响应。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: 统一响应请求对象。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
APIResponse: 统一响应对象。
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RespParseException: 插件返回内容无法转换为统一响应时抛出。
|
||||||
|
"""
|
||||||
|
if request.stream_response_handler is not None or request.async_response_parser is not None:
|
||||||
|
raise RespParseException(message="插件 LLM Provider 暂不支持 Host 侧自定义流式处理器或响应解析器")
|
||||||
|
payload = serialize_response_request_snapshot(request)
|
||||||
|
result = await self._invoke_provider("response", payload)
|
||||||
|
return self._build_api_response(result, request.model_info.name, request.model_info.api_provider)
|
||||||
|
|
||||||
|
async def get_embedding(self, request: EmbeddingRequest) -> APIResponse:
|
||||||
|
"""获取文本嵌入。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: 统一嵌入请求对象。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
APIResponse: 嵌入响应。
|
||||||
|
"""
|
||||||
|
result = await self._invoke_provider("embedding", serialize_embedding_request_snapshot(request))
|
||||||
|
return self._build_api_response(result, request.model_info.name, request.model_info.api_provider)
|
||||||
|
|
||||||
|
async def get_audio_transcriptions(self, request: AudioTranscriptionRequest) -> APIResponse:
|
||||||
|
"""获取音频转录。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: 统一音频转录请求对象。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
APIResponse: 音频转录响应。
|
||||||
|
"""
|
||||||
|
result = await self._invoke_provider("audio_transcription", serialize_audio_request_snapshot(request))
|
||||||
|
return self._build_api_response(result, request.model_info.name, request.model_info.api_provider)
|
||||||
|
|
||||||
|
def get_support_image_formats(self) -> List[str]:
|
||||||
|
"""获取支持的图片格式。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[str]: 插件 Provider 默认接收的图片格式列表。
|
||||||
|
"""
|
||||||
|
return ["jpeg", "jpg", "png", "webp"]
|
||||||
|
|
||||||
|
def _build_api_provider_snapshot(self) -> Dict[str, Any]:
|
||||||
|
"""构建可传给插件 Provider 的 API Provider 配置快照。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Any]: 包含认证信息的 API Provider 配置字典。
|
||||||
|
"""
|
||||||
|
return self.api_provider.model_dump(mode="json")
|
||||||
|
|
||||||
|
async def _invoke_provider(self, operation: str, request_payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""调用插件 Provider。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
operation: 请求操作类型。
|
||||||
|
request_payload: 已序列化的内部请求。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Any]: 插件返回的响应字典。
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RespParseException: 插件调用失败或返回格式不合法时抛出。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
response = await self._supervisor.invoke_llm_provider(
|
||||||
|
plugin_id=self._plugin_id,
|
||||||
|
client_type=self._client_type,
|
||||||
|
operation=operation,
|
||||||
|
request={
|
||||||
|
**request_payload,
|
||||||
|
"api_provider": self._build_api_provider_snapshot(),
|
||||||
|
},
|
||||||
|
timeout_ms=max(1000, int(self.api_provider.timeout) * 1000),
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
raise RespParseException(message=f"插件 LLM Provider RPC 调用失败: {exc}") from exc
|
||||||
|
if response.error:
|
||||||
|
raise RespParseException(message=str(response.error.get("message", "插件 LLM Provider 调用失败")))
|
||||||
|
|
||||||
|
payload = response.payload if isinstance(response.payload, dict) else {}
|
||||||
|
success = bool(payload.get("success", False))
|
||||||
|
result = payload.get("result")
|
||||||
|
if not success:
|
||||||
|
raise RespParseException(message=str(result or "插件 LLM Provider 返回失败"))
|
||||||
|
if not isinstance(result, dict):
|
||||||
|
raise RespParseException(message="插件 LLM Provider 返回值必须是字典")
|
||||||
|
return result
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_usage_record(raw_usage: Any, model_name: str, provider_name: str) -> UsageRecord | None:
|
||||||
|
"""从插件返回值恢复使用量记录。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
raw_usage: 插件返回的使用量字段。
|
||||||
|
model_name: 当前模型名称。
|
||||||
|
provider_name: 当前 Provider 名称。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UsageRecord | None: 可挂载到 APIResponse 的使用量记录;缺失时返回 ``None``。
|
||||||
|
"""
|
||||||
|
if not isinstance(raw_usage, dict):
|
||||||
|
return None
|
||||||
|
return UsageRecord(
|
||||||
|
model_name=str(raw_usage.get("model_name") or model_name),
|
||||||
|
provider_name=str(raw_usage.get("provider_name") or provider_name),
|
||||||
|
prompt_tokens=int(raw_usage.get("prompt_tokens") or 0),
|
||||||
|
completion_tokens=int(raw_usage.get("completion_tokens") or 0),
|
||||||
|
total_tokens=int(raw_usage.get("total_tokens") or 0),
|
||||||
|
prompt_cache_hit_tokens=int(raw_usage.get("prompt_cache_hit_tokens") or 0),
|
||||||
|
prompt_cache_miss_tokens=int(raw_usage.get("prompt_cache_miss_tokens") or 0),
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_api_response(result: Dict[str, Any], model_name: str, provider_name: str) -> APIResponse:
|
||||||
|
"""从插件返回值恢复统一 APIResponse。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
result: 插件返回的响应字典。
|
||||||
|
model_name: 当前模型名称。
|
||||||
|
provider_name: 当前 Provider 名称。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
APIResponse: 统一响应对象。
|
||||||
|
"""
|
||||||
|
raw_embedding = result.get("embedding")
|
||||||
|
embedding = [float(item) for item in raw_embedding] if isinstance(raw_embedding, list) else None
|
||||||
|
content = result.get("content")
|
||||||
|
if not isinstance(content, str):
|
||||||
|
content = result.get("response")
|
||||||
|
reasoning_content = result.get("reasoning_content")
|
||||||
|
if not isinstance(reasoning_content, str):
|
||||||
|
reasoning_content = result.get("reasoning")
|
||||||
|
return APIResponse(
|
||||||
|
content=content if isinstance(content, str) else None,
|
||||||
|
reasoning_content=reasoning_content if isinstance(reasoning_content, str) else None,
|
||||||
|
tool_calls=deserialize_tool_calls_snapshot(result.get("tool_calls")) or None,
|
||||||
|
embedding=embedding,
|
||||||
|
usage=PluginLLMClient._build_usage_record(result.get("usage"), model_name, provider_name),
|
||||||
|
raw_data=result.get("raw_data", result),
|
||||||
|
)
|
||||||
@@ -1072,3 +1072,65 @@ class TempMethodsLLMUtils:
|
|||||||
if provider.name == provider_name:
|
if provider.name == provider_name:
|
||||||
return provider
|
return provider
|
||||||
raise ValueError(f"未找到名为 '{provider_name}' 的API提供商")
|
raise ValueError(f"未找到名为 '{provider_name}' 的API提供商")
|
||||||
|
|
||||||
|
|
||||||
|
class LLMRequest(LLMOrchestrator):
|
||||||
|
"""兼容旧调用方的 LLM 请求入口。
|
||||||
|
|
||||||
|
新代码应优先使用 ``LLMOrchestrator`` 或服务层 ``LLMServiceClient``;
|
||||||
|
该类保留旧版 ``model_set=TaskConfig`` 的构造方式,并在运行时解析
|
||||||
|
对应的最新任务配置名称。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model_set: TaskConfig, request_type: str = "") -> None:
|
||||||
|
"""初始化旧版 LLM 请求对象。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_set: 旧调用方传入的任务配置对象。
|
||||||
|
request_type: 当前请求的业务类型标识。
|
||||||
|
"""
|
||||||
|
self._task_config_name = self._resolve_task_config_name(model_set)
|
||||||
|
super().__init__(task_name=self._task_config_name, request_type=request_type)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_task_config_signature(task_config: TaskConfig) -> Tuple[Any, ...]:
|
||||||
|
"""构造任务配置签名。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_config: 任务配置对象。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[Any, ...]: 可用于匹配热重载前后任务配置的签名。
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
tuple(task_config.model_list),
|
||||||
|
task_config.max_tokens,
|
||||||
|
task_config.temperature,
|
||||||
|
task_config.slow_threshold,
|
||||||
|
task_config.selection_strategy,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _resolve_task_config_name(cls, model_set: TaskConfig) -> str:
|
||||||
|
"""根据旧版 TaskConfig 对象解析任务配置名称。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_set: 旧调用方传入的任务配置对象。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 对应 ``model_task_config`` 下的字段名。
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 未能找到匹配任务配置时抛出。
|
||||||
|
"""
|
||||||
|
target_signature = cls._build_task_config_signature(model_set)
|
||||||
|
model_task_config = config_manager.get_model_config().model_task_config
|
||||||
|
for attr_name in dir(model_task_config):
|
||||||
|
if attr_name.startswith("_"):
|
||||||
|
continue
|
||||||
|
attr_value = getattr(model_task_config, attr_name)
|
||||||
|
if not isinstance(attr_value, TaskConfig):
|
||||||
|
continue
|
||||||
|
if attr_value is model_set or cls._build_task_config_signature(attr_value) == target_signature:
|
||||||
|
return attr_name
|
||||||
|
raise ValueError("无法根据旧版 model_set 解析任务配置名称")
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ import sys
|
|||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import config_manager, global_config
|
from src.config.config import config_manager, global_config
|
||||||
|
from src.llm_models.model_client.base_client import ClientProviderRegistration, client_registry
|
||||||
|
from src.llm_models.model_client.plugin_client import PluginLLMClient
|
||||||
from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, get_platform_io_manager
|
from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, get_platform_io_manager
|
||||||
from src.platform_io.drivers import PluginPlatformDriver
|
from src.platform_io.drivers import PluginPlatformDriver
|
||||||
from src.platform_io.route_key_factory import RouteKeyFactory
|
from src.platform_io.route_key_factory import RouteKeyFactory
|
||||||
@@ -30,6 +32,7 @@ from src.plugin_runtime.protocol.envelope import (
|
|||||||
HealthPayload,
|
HealthPayload,
|
||||||
InspectPluginConfigPayload,
|
InspectPluginConfigPayload,
|
||||||
InspectPluginConfigResultPayload,
|
InspectPluginConfigResultPayload,
|
||||||
|
LLMProviderInvokePayload,
|
||||||
MessageGatewayStateUpdatePayload,
|
MessageGatewayStateUpdatePayload,
|
||||||
MessageGatewayStateUpdateResultPayload,
|
MessageGatewayStateUpdateResultPayload,
|
||||||
PROTOCOL_VERSION,
|
PROTOCOL_VERSION,
|
||||||
@@ -417,6 +420,38 @@ class PluginRunnerSupervisor:
|
|||||||
timeout_ms=timeout_ms,
|
timeout_ms=timeout_ms,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def invoke_llm_provider(
|
||||||
|
self,
|
||||||
|
plugin_id: str,
|
||||||
|
client_type: str,
|
||||||
|
operation: str,
|
||||||
|
request: Optional[Dict[str, Any]] = None,
|
||||||
|
timeout_ms: int = 30000,
|
||||||
|
) -> Envelope:
|
||||||
|
"""调用插件声明的 LLM Provider。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_id: 目标插件 ID。
|
||||||
|
client_type: 目标客户端类型。
|
||||||
|
operation: 请求操作类型。
|
||||||
|
request: 已序列化的 LLM 请求。
|
||||||
|
timeout_ms: RPC 超时时间,单位毫秒。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Envelope: Runner 返回的响应信封。
|
||||||
|
"""
|
||||||
|
payload = LLMProviderInvokePayload(
|
||||||
|
client_type=client_type,
|
||||||
|
operation=operation,
|
||||||
|
request=request or {},
|
||||||
|
)
|
||||||
|
return await self._rpc_server.send_request(
|
||||||
|
"plugin.invoke_llm_provider",
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
payload=payload.model_dump(),
|
||||||
|
timeout_ms=timeout_ms,
|
||||||
|
)
|
||||||
|
|
||||||
async def invoke_api(
|
async def invoke_api(
|
||||||
self,
|
self,
|
||||||
plugin_id: str,
|
plugin_id: str,
|
||||||
@@ -779,6 +814,22 @@ class PluginRunnerSupervisor:
|
|||||||
|
|
||||||
component_declarations = [component.model_dump() for component in payload.components]
|
component_declarations = [component.model_dump() for component in payload.components]
|
||||||
runtime_components, api_components = self._split_component_declarations(component_declarations)
|
runtime_components, api_components = self._split_component_declarations(component_declarations)
|
||||||
|
try:
|
||||||
|
client_registry.validate_plugin_provider_replacement(
|
||||||
|
payload.plugin_id,
|
||||||
|
[provider.client_type for provider in payload.llm_providers],
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(f"插件 {payload.plugin_id} LLM Provider 注册校验失败: {exc}")
|
||||||
|
return envelope.make_error_response(
|
||||||
|
ErrorCode.E_BAD_PAYLOAD.value,
|
||||||
|
str(exc),
|
||||||
|
details={
|
||||||
|
"plugin_id": payload.plugin_id,
|
||||||
|
"llm_provider_count": len(payload.llm_providers),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
registered_count = self._component_registry.register_plugin_components(
|
registered_count = self._component_registry.register_plugin_components(
|
||||||
payload.plugin_id,
|
payload.plugin_id,
|
||||||
@@ -798,6 +849,24 @@ class PluginRunnerSupervisor:
|
|||||||
self._api_registry.remove_apis_by_plugin(payload.plugin_id)
|
self._api_registry.remove_apis_by_plugin(payload.plugin_id)
|
||||||
registered_api_count = self._api_registry.register_plugin_apis(payload.plugin_id, api_components)
|
registered_api_count = self._api_registry.register_plugin_apis(payload.plugin_id, api_components)
|
||||||
await self._unregister_all_message_gateway_drivers_for_plugin(payload.plugin_id)
|
await self._unregister_all_message_gateway_drivers_for_plugin(payload.plugin_id)
|
||||||
|
client_registry.replace_plugin_providers(
|
||||||
|
payload.plugin_id,
|
||||||
|
[
|
||||||
|
ClientProviderRegistration(
|
||||||
|
client_type=provider.client_type,
|
||||||
|
factory=lambda api_provider, provider_client_type=provider.client_type: PluginLLMClient(
|
||||||
|
api_provider=api_provider,
|
||||||
|
supervisor=self,
|
||||||
|
plugin_id=payload.plugin_id,
|
||||||
|
client_type=provider_client_type,
|
||||||
|
),
|
||||||
|
owner_plugin_id=payload.plugin_id,
|
||||||
|
version=provider.version,
|
||||||
|
description=provider.description or provider.name,
|
||||||
|
)
|
||||||
|
for provider in payload.llm_providers
|
||||||
|
],
|
||||||
|
)
|
||||||
self._registered_plugins[payload.plugin_id] = payload
|
self._registered_plugins[payload.plugin_id] = payload
|
||||||
self._message_gateway_states[payload.plugin_id] = {}
|
self._message_gateway_states[payload.plugin_id] = {}
|
||||||
|
|
||||||
@@ -810,6 +879,7 @@ class PluginRunnerSupervisor:
|
|||||||
"message_gateways": len(
|
"message_gateways": len(
|
||||||
self._component_registry.get_message_gateways(plugin_id=payload.plugin_id, enabled_only=False)
|
self._component_registry.get_message_gateways(plugin_id=payload.plugin_id, enabled_only=False)
|
||||||
),
|
),
|
||||||
|
"llm_providers": len(payload.llm_providers),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -829,6 +899,7 @@ class PluginRunnerSupervisor:
|
|||||||
|
|
||||||
removed_components = self._component_registry.remove_components_by_plugin(payload.plugin_id)
|
removed_components = self._component_registry.remove_components_by_plugin(payload.plugin_id)
|
||||||
removed_apis = self._api_registry.remove_apis_by_plugin(payload.plugin_id)
|
removed_apis = self._api_registry.remove_apis_by_plugin(payload.plugin_id)
|
||||||
|
removed_llm_providers = client_registry.unregister_plugin_providers(payload.plugin_id)
|
||||||
self._authorization.revoke_permission_token(payload.plugin_id)
|
self._authorization.revoke_permission_token(payload.plugin_id)
|
||||||
removed_registration = self._registered_plugins.pop(payload.plugin_id, None) is not None
|
removed_registration = self._registered_plugins.pop(payload.plugin_id, None) is not None
|
||||||
await self._unregister_all_message_gateway_drivers_for_plugin(payload.plugin_id)
|
await self._unregister_all_message_gateway_drivers_for_plugin(payload.plugin_id)
|
||||||
@@ -841,6 +912,7 @@ class PluginRunnerSupervisor:
|
|||||||
"reason": payload.reason,
|
"reason": payload.reason,
|
||||||
"removed_components": removed_components,
|
"removed_components": removed_components,
|
||||||
"removed_apis": removed_apis,
|
"removed_apis": removed_apis,
|
||||||
|
"removed_llm_providers": removed_llm_providers,
|
||||||
"removed_registration": removed_registration,
|
"removed_registration": removed_registration,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -1505,6 +1577,8 @@ class PluginRunnerSupervisor:
|
|||||||
|
|
||||||
def _clear_runner_state(self) -> None:
|
def _clear_runner_state(self) -> None:
|
||||||
"""清理当前 Runner 对应的 Host 侧注册状态。"""
|
"""清理当前 Runner 对应的 Host 侧注册状态。"""
|
||||||
|
for plugin_id in list(self._registered_plugins):
|
||||||
|
client_registry.unregister_plugin_providers(plugin_id)
|
||||||
self._authorization.clear()
|
self._authorization.clear()
|
||||||
self._api_registry.clear()
|
self._api_registry.clear()
|
||||||
self._component_registry.clear()
|
self._component_registry.clear()
|
||||||
|
|||||||
@@ -148,6 +148,35 @@ class PluginRuntimeManager(
|
|||||||
validator = ManifestValidator(validate_python_package_dependencies=False)
|
validator = ManifestValidator(validate_python_package_dependencies=False)
|
||||||
return validator.build_plugin_dependency_map(plugin_dirs)
|
return validator.build_plugin_dependency_map(plugin_dirs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _discover_llm_provider_conflicts(cls, plugin_dirs: Iterable[Path]) -> Dict[str, str]:
|
||||||
|
"""扫描插件 Manifest,发现 LLM Provider client_type 冲突。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
plugin_dirs: 需要扫描的插件根目录集合。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, str]: 需要阻止加载的插件 ID 与原因映射。
|
||||||
|
"""
|
||||||
|
validator = ManifestValidator(validate_python_package_dependencies=False)
|
||||||
|
provider_owners: Dict[str, List[str]] = {}
|
||||||
|
for _plugin_path, manifest in validator.iter_plugin_manifests(plugin_dirs, require_entrypoint=True):
|
||||||
|
for client_type in manifest.llm_provider_client_types:
|
||||||
|
provider_owners.setdefault(client_type, []).append(manifest.id)
|
||||||
|
|
||||||
|
blocked_reasons: Dict[str, str] = {}
|
||||||
|
for client_type, plugin_ids in provider_owners.items():
|
||||||
|
unique_plugin_ids = sorted(set(plugin_ids))
|
||||||
|
if len(unique_plugin_ids) <= 1:
|
||||||
|
continue
|
||||||
|
reason = (
|
||||||
|
f"LLM Provider client_type 冲突: {client_type} 被以下插件重复声明: "
|
||||||
|
f"{', '.join(unique_plugin_ids)}"
|
||||||
|
)
|
||||||
|
for plugin_id in unique_plugin_ids:
|
||||||
|
blocked_reasons[plugin_id] = reason
|
||||||
|
return blocked_reasons
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _build_group_start_order(
|
def _build_group_start_order(
|
||||||
cls,
|
cls,
|
||||||
@@ -271,7 +300,11 @@ class PluginRuntimeManager(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
result = await self._plugin_dependency_pipeline.execute(plugin_dirs)
|
result = await self._plugin_dependency_pipeline.execute(plugin_dirs)
|
||||||
changed_plugin_ids = self._set_blocked_plugin_reasons(result.blocked_plugin_reasons)
|
blocked_plugin_reasons = {
|
||||||
|
**result.blocked_plugin_reasons,
|
||||||
|
**self._discover_llm_provider_conflicts(plugin_dirs),
|
||||||
|
}
|
||||||
|
changed_plugin_ids = self._set_blocked_plugin_reasons(blocked_plugin_reasons)
|
||||||
return DependencySyncState(
|
return DependencySyncState(
|
||||||
blocked_changed_plugin_ids=changed_plugin_ids,
|
blocked_changed_plugin_ids=changed_plugin_ids,
|
||||||
environment_changed=result.environment_changed,
|
environment_changed=result.environment_changed,
|
||||||
|
|||||||
@@ -199,6 +199,21 @@ class ComponentDeclaration(BaseModel):
|
|||||||
"""组件元数据"""
|
"""组件元数据"""
|
||||||
|
|
||||||
|
|
||||||
|
class LLMProviderDeclaration(BaseModel):
|
||||||
|
"""单个 LLM Provider 声明。"""
|
||||||
|
|
||||||
|
client_type: str = Field(description="客户端类型标识,对应模型配置中的 api_providers[].client_type")
|
||||||
|
"""客户端类型标识。"""
|
||||||
|
name: str = Field(default="", description="Provider 展示名称")
|
||||||
|
"""Provider 展示名称。"""
|
||||||
|
description: str = Field(default="", description="Provider 描述")
|
||||||
|
"""Provider 描述。"""
|
||||||
|
version: str = Field(default="1.0.0", description="Provider 实现版本")
|
||||||
|
"""Provider 实现版本。"""
|
||||||
|
metadata: Dict[str, Any] = Field(default_factory=dict, description="Provider 元数据")
|
||||||
|
"""Provider 元数据。"""
|
||||||
|
|
||||||
|
|
||||||
class RegisterPluginPayload(BaseModel):
|
class RegisterPluginPayload(BaseModel):
|
||||||
"""插件组件注册请求载荷。
|
"""插件组件注册请求载荷。
|
||||||
|
|
||||||
@@ -212,6 +227,8 @@ class RegisterPluginPayload(BaseModel):
|
|||||||
"""插件版本"""
|
"""插件版本"""
|
||||||
components: List[ComponentDeclaration] = Field(default_factory=list, description="组件列表")
|
components: List[ComponentDeclaration] = Field(default_factory=list, description="组件列表")
|
||||||
"""组件列表"""
|
"""组件列表"""
|
||||||
|
llm_providers: List[LLMProviderDeclaration] = Field(default_factory=list, description="LLM Provider 声明列表")
|
||||||
|
"""LLM Provider 声明列表。"""
|
||||||
capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表")
|
capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表")
|
||||||
"""所需能力列表"""
|
"""所需能力列表"""
|
||||||
dependencies: List[str] = Field(default_factory=list, description="插件级依赖插件 ID 列表")
|
dependencies: List[str] = Field(default_factory=list, description="插件级依赖插件 ID 列表")
|
||||||
@@ -254,6 +271,17 @@ class InvokeResultPayload(BaseModel):
|
|||||||
"""返回值"""
|
"""返回值"""
|
||||||
|
|
||||||
|
|
||||||
|
class LLMProviderInvokePayload(BaseModel):
|
||||||
|
"""plugin.invoke_llm_provider 请求 payload。"""
|
||||||
|
|
||||||
|
client_type: str = Field(description="目标 LLM Provider 客户端类型")
|
||||||
|
"""目标 LLM Provider 客户端类型。"""
|
||||||
|
operation: str = Field(description="请求操作类型")
|
||||||
|
"""请求操作类型,如 response、embedding、audio_transcription。"""
|
||||||
|
request: Dict[str, Any] = Field(default_factory=dict, description="已序列化的 LLM 请求")
|
||||||
|
"""已序列化的 LLM 请求。"""
|
||||||
|
|
||||||
|
|
||||||
# ====== 能力调用消息 ======
|
# ====== 能力调用消息 ======
|
||||||
class CapabilityRequestPayload(BaseModel):
|
class CapabilityRequestPayload(BaseModel):
|
||||||
"""cap.* 请求 payload(插件 -> Host 能力调用)"""
|
"""cap.* 请求 payload(插件 -> Host 能力调用)"""
|
||||||
|
|||||||
@@ -7,7 +7,7 @@
|
|||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from importlib import metadata as importlib_metadata
|
from importlib import metadata as importlib_metadata
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Annotated, Any, Dict, Iterable, List, Literal, Optional, Tuple, Union
|
from typing import Annotated, Any, Dict, Iterable, List, Literal, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
@@ -453,6 +453,38 @@ class PythonPackageDependencyDefinition(_StrictManifestModel):
|
|||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
class LLMProviderManifestDeclaration(_StrictManifestModel):
|
||||||
|
"""插件 Manifest 中声明的 LLM Provider。"""
|
||||||
|
|
||||||
|
client_type: str = Field(description="客户端类型标识,对应模型配置中的 api_providers[].client_type")
|
||||||
|
"""客户端类型标识。"""
|
||||||
|
name: str = Field(default="", description="Provider 展示名称")
|
||||||
|
"""Provider 展示名称。"""
|
||||||
|
description: str = Field(default="", description="Provider 描述")
|
||||||
|
"""Provider 描述。"""
|
||||||
|
version: str = Field(default="1.0.0", description="Provider 实现版本")
|
||||||
|
"""Provider 实现版本。"""
|
||||||
|
|
||||||
|
@field_validator("client_type")
|
||||||
|
@classmethod
|
||||||
|
def _validate_client_type(cls, value: str) -> str:
|
||||||
|
"""校验客户端类型标识。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: 原始客户端类型标识。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 合法的客户端类型标识。
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 当客户端类型为空时抛出。
|
||||||
|
"""
|
||||||
|
normalized_value = str(value or "").strip()
|
||||||
|
if not normalized_value:
|
||||||
|
raise ValueError("client_type 不能为空")
|
||||||
|
return normalized_value
|
||||||
|
|
||||||
|
|
||||||
ManifestDependencyDefinition = Annotated[
|
ManifestDependencyDefinition = Annotated[
|
||||||
Union[PluginDependencyDefinition, PythonPackageDependencyDefinition],
|
Union[PluginDependencyDefinition, PythonPackageDependencyDefinition],
|
||||||
Field(discriminator="type"),
|
Field(discriminator="type"),
|
||||||
@@ -472,6 +504,10 @@ class PluginManifest(_StrictManifestModel):
|
|||||||
host_application: ManifestVersionRange = Field(description="Host 兼容区间")
|
host_application: ManifestVersionRange = Field(description="Host 兼容区间")
|
||||||
sdk: ManifestVersionRange = Field(description="SDK 兼容区间")
|
sdk: ManifestVersionRange = Field(description="SDK 兼容区间")
|
||||||
dependencies: List[ManifestDependencyDefinition] = Field(default_factory=list, description="依赖声明")
|
dependencies: List[ManifestDependencyDefinition] = Field(default_factory=list, description="依赖声明")
|
||||||
|
llm_providers: List[LLMProviderManifestDeclaration] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="插件静态声明的 LLM Provider 列表",
|
||||||
|
)
|
||||||
capabilities: List[str] = Field(description="插件声明的能力请求")
|
capabilities: List[str] = Field(description="插件声明的能力请求")
|
||||||
i18n: ManifestI18n = Field(description="国际化配置")
|
i18n: ManifestI18n = Field(description="国际化配置")
|
||||||
id: str = Field(description="稳定插件 ID")
|
id: str = Field(description="稳定插件 ID")
|
||||||
@@ -567,6 +603,23 @@ class PluginManifest(_StrictManifestModel):
|
|||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def _validate_llm_providers(self) -> "PluginManifest":
|
||||||
|
"""校验 LLM Provider 静态声明集合。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PluginManifest: 当前对象本身。
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 当同一 Manifest 内重复声明 client_type 时抛出。
|
||||||
|
"""
|
||||||
|
client_types: Set[str] = set()
|
||||||
|
for provider in self.llm_providers:
|
||||||
|
if provider.client_type in client_types:
|
||||||
|
raise ValueError(f"存在重复的 LLM Provider 声明: {provider.client_type}")
|
||||||
|
client_types.add(provider.client_type)
|
||||||
|
return self
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def plugin_dependencies(self) -> List[PluginDependencyDefinition]:
|
def plugin_dependencies(self) -> List[PluginDependencyDefinition]:
|
||||||
"""返回插件级依赖列表。
|
"""返回插件级依赖列表。
|
||||||
@@ -598,6 +651,15 @@ class PluginManifest(_StrictManifestModel):
|
|||||||
"""
|
"""
|
||||||
return [dependency.id for dependency in self.plugin_dependencies]
|
return [dependency.id for dependency in self.plugin_dependencies]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def llm_provider_client_types(self) -> List[str]:
|
||||||
|
"""返回 Manifest 静态声明的 LLM Provider client_type 列表。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[str]: 当前插件声明的 LLM Provider client_type。
|
||||||
|
"""
|
||||||
|
return [provider.client_type for provider in self.llm_providers]
|
||||||
|
|
||||||
|
|
||||||
class ManifestValidator:
|
class ManifestValidator:
|
||||||
"""严格的插件 Manifest v2 校验器。"""
|
"""严格的插件 Manifest v2 校验器。"""
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ class PluginMeta:
|
|||||||
self.capabilities_required = list(manifest.capabilities)
|
self.capabilities_required = list(manifest.capabilities)
|
||||||
self.dependencies: List[str] = list(manifest.plugin_dependency_ids)
|
self.dependencies: List[str] = list(manifest.plugin_dependency_ids)
|
||||||
self.component_handlers: Dict[str, str] = {}
|
self.component_handlers: Dict[str, str] = {}
|
||||||
|
self.llm_provider_handlers: Dict[str, str] = {}
|
||||||
|
|
||||||
|
|
||||||
class PluginLoader:
|
class PluginLoader:
|
||||||
|
|||||||
@@ -48,6 +48,8 @@ from src.plugin_runtime.protocol.envelope import (
|
|||||||
InspectPluginConfigResultPayload,
|
InspectPluginConfigResultPayload,
|
||||||
InvokePayload,
|
InvokePayload,
|
||||||
InvokeResultPayload,
|
InvokeResultPayload,
|
||||||
|
LLMProviderDeclaration,
|
||||||
|
LLMProviderInvokePayload,
|
||||||
RegisterPluginPayload,
|
RegisterPluginPayload,
|
||||||
ReloadPluginPayload,
|
ReloadPluginPayload,
|
||||||
ReloadPluginResultPayload,
|
ReloadPluginResultPayload,
|
||||||
@@ -891,6 +893,7 @@ class PluginRunner:
|
|||||||
self._rpc_client.register_method("plugin.invoke_api", self._handle_invoke)
|
self._rpc_client.register_method("plugin.invoke_api", self._handle_invoke)
|
||||||
self._rpc_client.register_method("plugin.invoke_tool", self._handle_invoke)
|
self._rpc_client.register_method("plugin.invoke_tool", self._handle_invoke)
|
||||||
self._rpc_client.register_method("plugin.invoke_message_gateway", self._handle_invoke)
|
self._rpc_client.register_method("plugin.invoke_message_gateway", self._handle_invoke)
|
||||||
|
self._rpc_client.register_method("plugin.invoke_llm_provider", self._handle_llm_provider_invoke)
|
||||||
self._rpc_client.register_method("plugin.emit_event", self._handle_event_invoke)
|
self._rpc_client.register_method("plugin.emit_event", self._handle_event_invoke)
|
||||||
self._rpc_client.register_method("plugin.invoke_hook", self._handle_hook_invoke)
|
self._rpc_client.register_method("plugin.invoke_hook", self._handle_hook_invoke)
|
||||||
self._rpc_client.register_method("plugin.health", self._handle_health)
|
self._rpc_client.register_method("plugin.health", self._handle_health)
|
||||||
@@ -980,6 +983,7 @@ class PluginRunner:
|
|||||||
"""
|
"""
|
||||||
# 收集插件组件声明
|
# 收集插件组件声明
|
||||||
components: List[ComponentDeclaration] = []
|
components: List[ComponentDeclaration] = []
|
||||||
|
llm_providers: List[LLMProviderDeclaration] = []
|
||||||
config_reload_subscriptions: List[str] = []
|
config_reload_subscriptions: List[str] = []
|
||||||
instance = meta.instance
|
instance = meta.instance
|
||||||
|
|
||||||
@@ -1016,11 +1020,43 @@ class PluginRunner:
|
|||||||
)
|
)
|
||||||
if hasattr(instance, "get_config_reload_subscriptions"):
|
if hasattr(instance, "get_config_reload_subscriptions"):
|
||||||
config_reload_subscriptions = list(instance.get_config_reload_subscriptions())
|
config_reload_subscriptions = list(instance.get_config_reload_subscriptions())
|
||||||
|
if hasattr(instance, "get_llm_providers"):
|
||||||
|
meta.llm_provider_handlers.clear()
|
||||||
|
for provider_info in instance.get_llm_providers():
|
||||||
|
if not isinstance(provider_info, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
|
client_type = str(provider_info.get("client_type", "") or "").strip()
|
||||||
|
raw_metadata = provider_info.get("metadata", {})
|
||||||
|
provider_metadata = raw_metadata if isinstance(raw_metadata, dict) else {}
|
||||||
|
if client_type:
|
||||||
|
handler_name = str(provider_metadata.get("handler_name", client_type) or client_type).strip()
|
||||||
|
meta.llm_provider_handlers[client_type] = handler_name or client_type
|
||||||
|
|
||||||
|
llm_providers.append(
|
||||||
|
LLMProviderDeclaration(
|
||||||
|
client_type=client_type,
|
||||||
|
name=str(provider_info.get("name", "") or "").strip(),
|
||||||
|
description=str(provider_info.get("description", "") or "").strip(),
|
||||||
|
version=str(provider_info.get("version", "1.0.0") or "1.0.0").strip() or "1.0.0",
|
||||||
|
metadata=provider_metadata,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
declared_client_types = sorted(meta.manifest.llm_provider_client_types)
|
||||||
|
registered_client_types = sorted(provider.client_type for provider in llm_providers)
|
||||||
|
if declared_client_types != registered_client_types:
|
||||||
|
logger.error(
|
||||||
|
f"插件 {meta.plugin_id} LLM Provider 声明不一致: "
|
||||||
|
f"manifest={declared_client_types}, code={registered_client_types}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
reg_payload = RegisterPluginPayload(
|
reg_payload = RegisterPluginPayload(
|
||||||
plugin_id=meta.plugin_id,
|
plugin_id=meta.plugin_id,
|
||||||
plugin_version=meta.version,
|
plugin_version=meta.version,
|
||||||
components=components,
|
components=components,
|
||||||
|
llm_providers=llm_providers,
|
||||||
capabilities_required=meta.capabilities_required,
|
capabilities_required=meta.capabilities_required,
|
||||||
dependencies=meta.dependencies,
|
dependencies=meta.dependencies,
|
||||||
config_reload_subscriptions=config_reload_subscriptions,
|
config_reload_subscriptions=config_reload_subscriptions,
|
||||||
@@ -1629,6 +1665,50 @@ class PluginRunner:
|
|||||||
resp_payload = InvokeResultPayload(success=False, result=str(e))
|
resp_payload = InvokeResultPayload(success=False, result=str(e))
|
||||||
return envelope.make_response(payload=resp_payload.model_dump())
|
return envelope.make_response(payload=resp_payload.model_dump())
|
||||||
|
|
||||||
|
async def _handle_llm_provider_invoke(self, envelope: Envelope) -> Envelope:
|
||||||
|
"""处理 LLM Provider 调用请求。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
envelope: RPC 请求信封。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Envelope: 标准化后的 Provider 调用结果。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
invoke = LLMProviderInvokePayload.model_validate(envelope.payload)
|
||||||
|
except Exception as exc:
|
||||||
|
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
|
||||||
|
|
||||||
|
plugin_id = envelope.plugin_id
|
||||||
|
meta = self._loader.get_plugin(plugin_id)
|
||||||
|
if meta is None:
|
||||||
|
return envelope.make_error_response(
|
||||||
|
ErrorCode.E_PLUGIN_NOT_FOUND.value,
|
||||||
|
f"插件 {plugin_id} 未加载",
|
||||||
|
)
|
||||||
|
|
||||||
|
handler_name = meta.llm_provider_handlers.get(invoke.client_type, "")
|
||||||
|
handler_method = getattr(meta.instance, handler_name, None) if handler_name else None
|
||||||
|
if handler_method is None or not callable(handler_method):
|
||||||
|
return envelope.make_error_response(
|
||||||
|
ErrorCode.E_METHOD_NOT_ALLOWED.value,
|
||||||
|
f"插件 {plugin_id} 未注册 LLM Provider: {invoke.client_type}",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = handler_method(operation=invoke.operation, request=invoke.request)
|
||||||
|
if inspect.isawaitable(result):
|
||||||
|
result = await result
|
||||||
|
resp_payload = InvokeResultPayload(success=True, result=result)
|
||||||
|
return envelope.make_response(payload=resp_payload.model_dump())
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(
|
||||||
|
f"插件 {plugin_id} LLM Provider {invoke.client_type} 执行异常: {exc}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
resp_payload = InvokeResultPayload(success=False, result=str(exc))
|
||||||
|
return envelope.make_response(payload=resp_payload.model_dump())
|
||||||
|
|
||||||
async def _handle_event_invoke(self, envelope: Envelope) -> Envelope:
|
async def _handle_event_invoke(self, envelope: Envelope) -> Envelope:
|
||||||
"""处理 EventHandler 调用请求
|
"""处理 EventHandler 调用请求
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user