Merge branch 'Mai-with-u:dev' into dev

This commit is contained in:
Dawn ARC
2026-04-28 10:02:59 +08:00
committed by GitHub
14 changed files with 948 additions and 18 deletions

View File

@@ -16,7 +16,7 @@ async def test_handle_file_changes_throttles_reload():
called = 0
async def reload_stub() -> bool:
async def reload_stub(changed_scopes=None) -> bool:
nonlocal called
called += 1
return True
@@ -36,7 +36,7 @@ async def test_handle_file_changes_timeout_logged(caplog):
manager._hot_reload_min_interval_s = 0.0
manager._hot_reload_timeout_s = 0.01
async def reload_stub() -> bool:
async def reload_stub(changed_scopes=None) -> bool:
await asyncio.sleep(0.05)
return True
@@ -55,7 +55,7 @@ async def test_handle_file_changes_empty_skips_reload():
called = 0
async def reload_stub() -> bool:
async def reload_stub(changed_scopes=None) -> bool:
nonlocal called
called += 1
return True

View File

@@ -0,0 +1,33 @@
from typing import Any
import pytest
from src.config import config as config_module
from src.config.config import Config, ConfigManager, ModelConfig
class _StartupUpgradeExit(Exception):
pass
def test_initialize_upgrades_bot_and_model_config_before_exit(monkeypatch):
manager = ConfigManager()
loaded_config_classes: list[type[Any]] = []
exit_codes: list[int | None] = []
def fake_load_config_from_file(config_class, config_path, new_ver, override_repr=False):
loaded_config_classes.append(config_class)
return object(), True
def fake_exit(code: int | None = None):
exit_codes.append(code)
raise _StartupUpgradeExit
monkeypatch.setattr(config_module, "load_config_from_file", fake_load_config_from_file)
monkeypatch.setattr(config_module.sys, "exit", fake_exit)
with pytest.raises(_StartupUpgradeExit):
manager.initialize()
assert loaded_config_classes == [Config, ModelConfig]
assert exit_codes == [0]

View File

@@ -0,0 +1,101 @@
from typing import List
from src.llm_models.model_client.base_client import (
APIResponse,
AudioTranscriptionRequest,
BaseClient,
ClientProviderRegistration,
ClientRegistry,
EmbeddingRequest,
ResponseRequest,
)
class DummyClient(BaseClient):
"""测试用 LLM 客户端。"""
async def get_response(self, request: ResponseRequest) -> APIResponse:
"""获取测试响应。
Args:
request: 统一响应请求。
Returns:
APIResponse: 测试响应。
"""
del request
return APIResponse(content="ok")
async def get_embedding(self, request: EmbeddingRequest) -> APIResponse:
"""获取测试嵌入。
Args:
request: 统一嵌入请求。
Returns:
APIResponse: 测试嵌入响应。
"""
del request
return APIResponse(embedding=[1.0])
async def get_audio_transcriptions(self, request: AudioTranscriptionRequest) -> APIResponse:
"""获取测试音频转写。
Args:
request: 统一音频转写请求。
Returns:
APIResponse: 测试音频转写响应。
"""
del request
return APIResponse(content="audio")
def get_support_image_formats(self) -> List[str]:
"""获取测试支持的图片格式。
Returns:
List[str]: 支持的图片格式列表。
"""
return ["png"]
def test_client_registry_rejects_provider_conflict():
"""同一 client_type 被不同插件注册时应拒绝。"""
registry = ClientRegistry()
registry.replace_plugin_providers(
"plugin.alpha",
[
ClientProviderRegistration(
client_type="example",
factory=DummyClient,
owner_plugin_id="plugin.alpha",
)
],
)
try:
registry.validate_plugin_provider_replacement("plugin.beta", ["example"])
except ValueError as exc:
assert "冲突" in str(exc)
else:
raise AssertionError("不同插件注册相同 client_type 应失败")
def test_client_registry_unregisters_plugin_providers():
"""插件注销时应移除它拥有的 Provider 注册。"""
registry = ClientRegistry()
registry.replace_plugin_providers(
"plugin.alpha",
[
ClientProviderRegistration(
client_type="example",
factory=DummyClient,
owner_plugin_id="plugin.alpha",
)
],
)
removed_count = registry.unregister_plugin_providers("plugin.alpha")
assert removed_count == 1
assert "example" not in registry.client_registry

View File

@@ -30,6 +30,7 @@ def build_test_manifest(
name: str = "测试插件",
description: str = "测试插件描述",
dependencies: list[dict[str, str]] | None = None,
llm_providers: list[dict[str, str]] | None = None,
capabilities: list[str] | None = None,
host_min_version: str = "0.12.0",
host_max_version: str = "1.0.0",
@@ -44,6 +45,7 @@ def build_test_manifest(
name: 展示名称。
description: 插件描述。
dependencies: 依赖声明列表。
llm_providers: LLM Provider 静态声明列表。
capabilities: 能力声明列表。
host_min_version: Host 最低支持版本。
host_max_version: Host 最高支持版本。
@@ -75,6 +77,7 @@ def build_test_manifest(
"max_version": sdk_max_version,
},
"dependencies": dependencies or [],
"llm_providers": llm_providers or [],
"capabilities": capabilities or [],
"i18n": {
"default_locale": "zh-CN",
@@ -89,6 +92,7 @@ def build_test_manifest_model(
*,
version: str = "1.0.0",
dependencies: list[dict[str, str]] | None = None,
llm_providers: list[dict[str, str]] | None = None,
capabilities: list[str] | None = None,
host_version: str = "1.0.0",
sdk_version: str = "2.0.1",
@@ -99,6 +103,7 @@ def build_test_manifest_model(
plugin_id: 插件 ID。
version: 插件版本。
dependencies: 依赖声明列表。
llm_providers: LLM Provider 静态声明列表。
capabilities: 能力声明列表。
host_version: 当前测试使用的 Host 版本。
sdk_version: 当前测试使用的 SDK 版本。
@@ -114,6 +119,7 @@ def build_test_manifest_model(
plugin_id,
version=version,
dependencies=dependencies,
llm_providers=llm_providers,
capabilities=capabilities,
)
)
@@ -1055,6 +1061,63 @@ class TestManifestValidator:
assert validator.validate(manifest) is False
assert any("Python 包依赖冲突" in error for error in validator.errors)
def test_llm_provider_manifest_declaration(self):
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
validator = ManifestValidator(host_version="1.0.0", sdk_version="2.0.1")
manifest = build_test_manifest(
"test.llm-provider",
llm_providers=[
{
"client_type": "example.provider",
"name": "Example Provider",
"description": "测试 Provider",
"version": "1.0.0",
}
],
)
parsed_manifest = validator.parse_manifest(manifest)
assert parsed_manifest is not None
assert parsed_manifest.llm_provider_client_types == ["example.provider"]
def test_duplicate_llm_provider_manifest_declaration_is_rejected(self):
from src.plugin_runtime.runner.manifest_validator import ManifestValidator
validator = ManifestValidator(host_version="1.0.0", sdk_version="2.0.1")
manifest = build_test_manifest(
"test.llm-provider-duplicate",
llm_providers=[
{"client_type": "example.provider"},
{"client_type": "example.provider"},
],
)
assert validator.validate(manifest) is False
assert any("重复的 LLM Provider" in error for error in validator.errors)
def test_llm_provider_conflict_blocks_all_conflicting_plugins(tmp_path: Path):
from src.plugin_runtime.integration import PluginRuntimeManager
plugin_root = tmp_path / "plugins"
plugin_root.mkdir()
for plugin_id in ["test.provider-alpha", "test.provider-beta"]:
plugin_dir = plugin_root / plugin_id
plugin_dir.mkdir()
manifest = build_test_manifest(
plugin_id,
llm_providers=[{"client_type": "example.provider"}],
)
(plugin_dir / "_manifest.json").write_text(json.dumps(manifest), encoding="utf-8")
(plugin_dir / "plugin.py").write_text("def create_plugin():\n return None\n", encoding="utf-8")
blocked_reasons = PluginRuntimeManager._discover_llm_provider_conflicts([plugin_root])
assert set(blocked_reasons) == {"test.provider-alpha", "test.provider-beta"}
assert all("example.provider" in reason for reason in blocked_reasons.values())
class TestVersionComparator:
"""版本号比较器测试"""

View File

@@ -196,8 +196,15 @@ class ConfigManager:
def initialize(self):
logger.info(t("config.current_version", version=MMC_VERSION))
logger.info(t("config.loading"))
self.global_config = self.load_global_config()
self.model_config = self.load_model_config()
self.global_config, global_updated = load_config_from_file(Config, self.bot_config_path, CONFIG_VERSION)
self.model_config, model_updated = load_config_from_file(
ModelConfig,
self.model_config_path,
MODEL_CONFIG_VERSION,
True,
)
if global_updated or model_updated:
sys.exit(0) # 配置已自动升级,退出一次让用户确认新配置后再启动
logger.info(t("config.loaded"))
def load_global_config(self) -> Config:

View File

@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Callable, Coroutine, Dict, List, Tuple, Type
from typing import Any, Callable, Coroutine, Dict, List, Set, Tuple, Type
import asyncio
@@ -204,19 +204,48 @@ class BaseClient(ABC):
raise NotImplementedError("'get_support_image_formats' method should be overridden in subclasses")
ClientFactory = Callable[[APIProvider], BaseClient]
"""根据 APIProvider 创建客户端实例的工厂函数。"""
@dataclass(slots=True)
class ClientProviderRegistration:
"""LLM Provider 客户端类型注册信息。"""
client_type: str
"""客户端类型标识,对应模型配置中的 `api_providers[].client_type`。"""
factory: ClientFactory
"""客户端实例工厂。"""
owner_plugin_id: str | None = None
"""拥有该客户端类型的插件 ID主程序内置类型为 ``None``。"""
version: str = "1.0.0"
"""Provider 实现版本。"""
description: str = ""
"""Provider 描述文本。"""
builtin: bool = False
"""是否为主程序内置 Provider。"""
class ClientRegistry:
"""客户端注册表。"""
def __init__(self) -> None:
"""初始化注册表并绑定配置重载回调。"""
self.client_registry: Dict[str, Type[BaseClient]] = {}
"""APIProvider.type -> BaseClient的映射表"""
self.client_registry: Dict[str, ClientProviderRegistration] = {}
"""APIProvider.client_type -> Provider 注册信息映射表"""
self.client_instance_cache: Dict[str, BaseClient] = {}
"""APIProvider.name -> BaseClient的映射表"""
self._owner_client_types: Dict[str, Set[str]] = {}
"""插件 ID -> 该插件拥有的 client_type 集合。"""
config_manager.register_reload_callback(self.clear_client_instance_cache)
def register_client_class(self, client_type: str) -> Callable[[Type[BaseClient]], Type[BaseClient]]:
"""注册 API 客户端类。
"""注册主程序内置 API 客户端类。
Args:
client_type: 客户端类型标识。
@@ -226,13 +255,180 @@ class ClientRegistry:
"""
def decorator(cls: Type[BaseClient]) -> Type[BaseClient]:
"""将内置客户端类注册到全局客户端注册表。
Args:
cls: 待注册的客户端类。
Returns:
Type[BaseClient]: 原始客户端类。
"""
if not issubclass(cls, BaseClient):
raise TypeError(f"{cls.__name__} is not a subclass of BaseClient")
self.client_registry[client_type] = cls
self.register_provider(
ClientProviderRegistration(
client_type=client_type,
factory=cls,
builtin=True,
)
)
return cls
return decorator
@staticmethod
def _normalize_client_type(client_type: str) -> str:
"""规范化客户端类型标识。
Args:
client_type: 原始客户端类型标识。
Returns:
str: 去除首尾空白后的客户端类型标识。
Raises:
ValueError: 当客户端类型为空时抛出。
"""
normalized_client_type = str(client_type or "").strip()
if not normalized_client_type:
raise ValueError("client_type 不能为空")
return normalized_client_type
def register_provider(self, registration: ClientProviderRegistration) -> None:
"""注册单个客户端类型。
Args:
registration: Provider 注册信息。
Raises:
ValueError: 当客户端类型冲突时抛出。
"""
client_type = self._normalize_client_type(registration.client_type)
existing = self.client_registry.get(client_type)
if existing is not None and existing.owner_plugin_id != registration.owner_plugin_id:
raise ValueError(
f"LLM Provider client_type 冲突: {client_type} 已由 {existing.owner_plugin_id or 'host'} 注册"
)
self.client_registry[client_type] = ClientProviderRegistration(
client_type=client_type,
factory=registration.factory,
owner_plugin_id=registration.owner_plugin_id,
version=registration.version,
description=registration.description,
builtin=registration.builtin,
)
if registration.owner_plugin_id:
self._owner_client_types.setdefault(registration.owner_plugin_id, set()).add(client_type)
self.clear_client_instance_cache_by_client_type(client_type)
def validate_plugin_provider_replacement(self, plugin_id: str, client_types: List[str]) -> None:
"""校验插件 Provider 替换是否会造成运行时冲突。
Args:
plugin_id: 目标插件 ID。
client_types: 插件即将注册的客户端类型列表。
Raises:
ValueError: 当客户端类型为空、重复或与其他 owner 冲突时抛出。
"""
normalized_plugin_id = str(plugin_id or "").strip()
if not normalized_plugin_id:
raise ValueError("plugin_id 不能为空")
normalized_client_types = [self._normalize_client_type(client_type) for client_type in client_types]
duplicate_client_types = sorted(
{
client_type
for client_type in normalized_client_types
if normalized_client_types.count(client_type) > 1
}
)
if duplicate_client_types:
raise ValueError(f"插件 {normalized_plugin_id} 重复声明 LLM Provider: {', '.join(duplicate_client_types)}")
for client_type in normalized_client_types:
existing = self.client_registry.get(client_type)
if existing is None or existing.owner_plugin_id == normalized_plugin_id:
continue
raise ValueError(
f"LLM Provider client_type 冲突: {client_type} 已由 {existing.owner_plugin_id or 'host'} 注册"
)
def replace_plugin_providers(
self,
plugin_id: str,
registrations: List[ClientProviderRegistration],
) -> None:
"""原子替换一个插件拥有的全部 Provider 注册。
Args:
plugin_id: 目标插件 ID。
registrations: 插件当前上报的 Provider 注册列表。
Raises:
ValueError: 当注册信息不合法或存在冲突时抛出。
"""
normalized_plugin_id = str(plugin_id or "").strip()
self.validate_plugin_provider_replacement(
normalized_plugin_id,
[registration.client_type for registration in registrations],
)
self.unregister_plugin_providers(normalized_plugin_id)
for registration in registrations:
self.register_provider(
ClientProviderRegistration(
client_type=registration.client_type,
factory=registration.factory,
owner_plugin_id=normalized_plugin_id,
version=registration.version,
description=registration.description,
builtin=False,
)
)
def unregister_plugin_providers(self, plugin_id: str) -> int:
"""注销一个插件拥有的全部 Provider 注册。
Args:
plugin_id: 目标插件 ID。
Returns:
int: 被注销的客户端类型数量。
"""
normalized_plugin_id = str(plugin_id or "").strip()
if not normalized_plugin_id:
return 0
client_types = self._owner_client_types.pop(normalized_plugin_id, set())
removed_count = 0
for client_type in client_types:
registration = self.client_registry.get(client_type)
if registration is None or registration.owner_plugin_id != normalized_plugin_id:
continue
self.client_registry.pop(client_type, None)
self.clear_client_instance_cache_by_client_type(client_type)
removed_count += 1
return removed_count
def clear_client_instance_cache_by_client_type(self, client_type: str) -> None:
"""清理指定客户端类型对应的客户端实例缓存。
Args:
client_type: 需要清理缓存的客户端类型。
"""
normalized_client_type = str(client_type or "").strip()
if not normalized_client_type:
return
stale_provider_names = [
provider_name
for provider_name, client in self.client_instance_cache.items()
if client.api_provider.client_type == normalized_client_type
]
for provider_name in stale_provider_names:
self.client_instance_cache.pop(provider_name, None)
def get_client_class_instance(self, api_provider: APIProvider, force_new: bool = False) -> BaseClient:
"""获取注册的 API 客户端实例。
@@ -249,15 +445,14 @@ class ClientRegistry:
# 如果强制创建新实例,直接创建不使用缓存
if force_new:
if client_class := self.client_registry.get(api_provider.client_type):
return client_class(api_provider)
else:
raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
if registration := self.client_registry.get(api_provider.client_type):
return registration.factory(api_provider)
raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
# 正常的缓存逻辑
if api_provider.name not in self.client_instance_cache:
if client_class := self.client_registry.get(api_provider.client_type):
self.client_instance_cache[api_provider.name] = client_class(api_provider)
if registration := self.client_registry.get(api_provider.client_type):
self.client_instance_cache[api_provider.name] = registration.factory(api_provider)
else:
raise KeyError(f"'{api_provider.client_type}' 类型的 Client 未注册")
return self.client_instance_cache[api_provider.name]

View File

@@ -0,0 +1,191 @@
from typing import Any, Dict, List
from src.config.model_configs import APIProvider
from src.llm_models.exceptions import RespParseException
from src.llm_models.model_client.base_client import (
APIResponse,
AudioTranscriptionRequest,
BaseClient,
EmbeddingRequest,
ResponseRequest,
UsageRecord,
)
from src.llm_models.request_snapshot import (
deserialize_tool_calls_snapshot,
serialize_audio_request_snapshot,
serialize_embedding_request_snapshot,
serialize_response_request_snapshot,
)
class PluginLLMClient(BaseClient):
"""通过插件 Runner RPC 调用第三方 LLM Provider 的客户端代理。"""
def __init__(
self,
api_provider: APIProvider,
supervisor: Any,
plugin_id: str,
client_type: str,
) -> None:
"""初始化插件 LLM Provider 客户端代理。
Args:
api_provider: 当前 API Provider 配置。
supervisor: 拥有目标插件的 Supervisor。
plugin_id: 目标插件 ID。
client_type: 目标客户端类型。
"""
super().__init__(api_provider)
self._supervisor = supervisor
self._plugin_id = plugin_id
self._client_type = client_type
async def get_response(self, request: ResponseRequest) -> APIResponse:
"""获取对话响应。
Args:
request: 统一响应请求对象。
Returns:
APIResponse: 统一响应对象。
Raises:
RespParseException: 插件返回内容无法转换为统一响应时抛出。
"""
if request.stream_response_handler is not None or request.async_response_parser is not None:
raise RespParseException(message="插件 LLM Provider 暂不支持 Host 侧自定义流式处理器或响应解析器")
payload = serialize_response_request_snapshot(request)
result = await self._invoke_provider("response", payload)
return self._build_api_response(result, request.model_info.name, request.model_info.api_provider)
async def get_embedding(self, request: EmbeddingRequest) -> APIResponse:
"""获取文本嵌入。
Args:
request: 统一嵌入请求对象。
Returns:
APIResponse: 嵌入响应。
"""
result = await self._invoke_provider("embedding", serialize_embedding_request_snapshot(request))
return self._build_api_response(result, request.model_info.name, request.model_info.api_provider)
async def get_audio_transcriptions(self, request: AudioTranscriptionRequest) -> APIResponse:
"""获取音频转录。
Args:
request: 统一音频转录请求对象。
Returns:
APIResponse: 音频转录响应。
"""
result = await self._invoke_provider("audio_transcription", serialize_audio_request_snapshot(request))
return self._build_api_response(result, request.model_info.name, request.model_info.api_provider)
def get_support_image_formats(self) -> List[str]:
"""获取支持的图片格式。
Returns:
List[str]: 插件 Provider 默认接收的图片格式列表。
"""
return ["jpeg", "jpg", "png", "webp"]
def _build_api_provider_snapshot(self) -> Dict[str, Any]:
"""构建可传给插件 Provider 的 API Provider 配置快照。
Returns:
Dict[str, Any]: 包含认证信息的 API Provider 配置字典。
"""
return self.api_provider.model_dump(mode="json")
async def _invoke_provider(self, operation: str, request_payload: Dict[str, Any]) -> Dict[str, Any]:
"""调用插件 Provider。
Args:
operation: 请求操作类型。
request_payload: 已序列化的内部请求。
Returns:
Dict[str, Any]: 插件返回的响应字典。
Raises:
RespParseException: 插件调用失败或返回格式不合法时抛出。
"""
try:
response = await self._supervisor.invoke_llm_provider(
plugin_id=self._plugin_id,
client_type=self._client_type,
operation=operation,
request={
**request_payload,
"api_provider": self._build_api_provider_snapshot(),
},
timeout_ms=max(1000, int(self.api_provider.timeout) * 1000),
)
except Exception as exc:
raise RespParseException(message=f"插件 LLM Provider RPC 调用失败: {exc}") from exc
if response.error:
raise RespParseException(message=str(response.error.get("message", "插件 LLM Provider 调用失败")))
payload = response.payload if isinstance(response.payload, dict) else {}
success = bool(payload.get("success", False))
result = payload.get("result")
if not success:
raise RespParseException(message=str(result or "插件 LLM Provider 返回失败"))
if not isinstance(result, dict):
raise RespParseException(message="插件 LLM Provider 返回值必须是字典")
return result
@staticmethod
def _build_usage_record(raw_usage: Any, model_name: str, provider_name: str) -> UsageRecord | None:
"""从插件返回值恢复使用量记录。
Args:
raw_usage: 插件返回的使用量字段。
model_name: 当前模型名称。
provider_name: 当前 Provider 名称。
Returns:
UsageRecord | None: 可挂载到 APIResponse 的使用量记录;缺失时返回 ``None``。
"""
if not isinstance(raw_usage, dict):
return None
return UsageRecord(
model_name=str(raw_usage.get("model_name") or model_name),
provider_name=str(raw_usage.get("provider_name") or provider_name),
prompt_tokens=int(raw_usage.get("prompt_tokens") or 0),
completion_tokens=int(raw_usage.get("completion_tokens") or 0),
total_tokens=int(raw_usage.get("total_tokens") or 0),
prompt_cache_hit_tokens=int(raw_usage.get("prompt_cache_hit_tokens") or 0),
prompt_cache_miss_tokens=int(raw_usage.get("prompt_cache_miss_tokens") or 0),
)
@staticmethod
def _build_api_response(result: Dict[str, Any], model_name: str, provider_name: str) -> APIResponse:
"""从插件返回值恢复统一 APIResponse。
Args:
result: 插件返回的响应字典。
model_name: 当前模型名称。
provider_name: 当前 Provider 名称。
Returns:
APIResponse: 统一响应对象。
"""
raw_embedding = result.get("embedding")
embedding = [float(item) for item in raw_embedding] if isinstance(raw_embedding, list) else None
content = result.get("content")
if not isinstance(content, str):
content = result.get("response")
reasoning_content = result.get("reasoning_content")
if not isinstance(reasoning_content, str):
reasoning_content = result.get("reasoning")
return APIResponse(
content=content if isinstance(content, str) else None,
reasoning_content=reasoning_content if isinstance(reasoning_content, str) else None,
tool_calls=deserialize_tool_calls_snapshot(result.get("tool_calls")) or None,
embedding=embedding,
usage=PluginLLMClient._build_usage_record(result.get("usage"), model_name, provider_name),
raw_data=result.get("raw_data", result),
)

View File

@@ -1072,3 +1072,65 @@ class TempMethodsLLMUtils:
if provider.name == provider_name:
return provider
raise ValueError(f"未找到名为 '{provider_name}' 的API提供商")
class LLMRequest(LLMOrchestrator):
"""兼容旧调用方的 LLM 请求入口。
新代码应优先使用 ``LLMOrchestrator`` 或服务层 ``LLMServiceClient``
该类保留旧版 ``model_set=TaskConfig`` 的构造方式,并在运行时解析
对应的最新任务配置名称。
"""
def __init__(self, model_set: TaskConfig, request_type: str = "") -> None:
"""初始化旧版 LLM 请求对象。
Args:
model_set: 旧调用方传入的任务配置对象。
request_type: 当前请求的业务类型标识。
"""
self._task_config_name = self._resolve_task_config_name(model_set)
super().__init__(task_name=self._task_config_name, request_type=request_type)
@staticmethod
def _build_task_config_signature(task_config: TaskConfig) -> Tuple[Any, ...]:
"""构造任务配置签名。
Args:
task_config: 任务配置对象。
Returns:
Tuple[Any, ...]: 可用于匹配热重载前后任务配置的签名。
"""
return (
tuple(task_config.model_list),
task_config.max_tokens,
task_config.temperature,
task_config.slow_threshold,
task_config.selection_strategy,
)
@classmethod
def _resolve_task_config_name(cls, model_set: TaskConfig) -> str:
"""根据旧版 TaskConfig 对象解析任务配置名称。
Args:
model_set: 旧调用方传入的任务配置对象。
Returns:
str: 对应 ``model_task_config`` 下的字段名。
Raises:
ValueError: 未能找到匹配任务配置时抛出。
"""
target_signature = cls._build_task_config_signature(model_set)
model_task_config = config_manager.get_model_config().model_task_config
for attr_name in dir(model_task_config):
if attr_name.startswith("_"):
continue
attr_value = getattr(model_task_config, attr_name)
if not isinstance(attr_value, TaskConfig):
continue
if attr_value is model_set or cls._build_task_config_signature(attr_value) == target_signature:
return attr_name
raise ValueError("无法根据旧版 model_set 解析任务配置名称")

View File

@@ -10,6 +10,8 @@ import sys
from src.common.logger import get_logger
from src.config.config import config_manager, global_config
from src.llm_models.model_client.base_client import ClientProviderRegistration, client_registry
from src.llm_models.model_client.plugin_client import PluginLLMClient
from src.platform_io import DriverKind, InboundMessageEnvelope, RouteBinding, RouteKey, get_platform_io_manager
from src.platform_io.drivers import PluginPlatformDriver
from src.platform_io.route_key_factory import RouteKeyFactory
@@ -30,6 +32,7 @@ from src.plugin_runtime.protocol.envelope import (
HealthPayload,
InspectPluginConfigPayload,
InspectPluginConfigResultPayload,
LLMProviderInvokePayload,
MessageGatewayStateUpdatePayload,
MessageGatewayStateUpdateResultPayload,
PROTOCOL_VERSION,
@@ -417,6 +420,38 @@ class PluginRunnerSupervisor:
timeout_ms=timeout_ms,
)
async def invoke_llm_provider(
self,
plugin_id: str,
client_type: str,
operation: str,
request: Optional[Dict[str, Any]] = None,
timeout_ms: int = 30000,
) -> Envelope:
"""调用插件声明的 LLM Provider。
Args:
plugin_id: 目标插件 ID。
client_type: 目标客户端类型。
operation: 请求操作类型。
request: 已序列化的 LLM 请求。
timeout_ms: RPC 超时时间,单位毫秒。
Returns:
Envelope: Runner 返回的响应信封。
"""
payload = LLMProviderInvokePayload(
client_type=client_type,
operation=operation,
request=request or {},
)
return await self._rpc_server.send_request(
"plugin.invoke_llm_provider",
plugin_id=plugin_id,
payload=payload.model_dump(),
timeout_ms=timeout_ms,
)
async def invoke_api(
self,
plugin_id: str,
@@ -779,6 +814,22 @@ class PluginRunnerSupervisor:
component_declarations = [component.model_dump() for component in payload.components]
runtime_components, api_components = self._split_component_declarations(component_declarations)
try:
client_registry.validate_plugin_provider_replacement(
payload.plugin_id,
[provider.client_type for provider in payload.llm_providers],
)
except Exception as exc:
logger.error(f"插件 {payload.plugin_id} LLM Provider 注册校验失败: {exc}")
return envelope.make_error_response(
ErrorCode.E_BAD_PAYLOAD.value,
str(exc),
details={
"plugin_id": payload.plugin_id,
"llm_provider_count": len(payload.llm_providers),
},
)
try:
registered_count = self._component_registry.register_plugin_components(
payload.plugin_id,
@@ -798,6 +849,24 @@ class PluginRunnerSupervisor:
self._api_registry.remove_apis_by_plugin(payload.plugin_id)
registered_api_count = self._api_registry.register_plugin_apis(payload.plugin_id, api_components)
await self._unregister_all_message_gateway_drivers_for_plugin(payload.plugin_id)
client_registry.replace_plugin_providers(
payload.plugin_id,
[
ClientProviderRegistration(
client_type=provider.client_type,
factory=lambda api_provider, provider_client_type=provider.client_type: PluginLLMClient(
api_provider=api_provider,
supervisor=self,
plugin_id=payload.plugin_id,
client_type=provider_client_type,
),
owner_plugin_id=payload.plugin_id,
version=provider.version,
description=provider.description or provider.name,
)
for provider in payload.llm_providers
],
)
self._registered_plugins[payload.plugin_id] = payload
self._message_gateway_states[payload.plugin_id] = {}
@@ -810,6 +879,7 @@ class PluginRunnerSupervisor:
"message_gateways": len(
self._component_registry.get_message_gateways(plugin_id=payload.plugin_id, enabled_only=False)
),
"llm_providers": len(payload.llm_providers),
}
)
@@ -829,6 +899,7 @@ class PluginRunnerSupervisor:
removed_components = self._component_registry.remove_components_by_plugin(payload.plugin_id)
removed_apis = self._api_registry.remove_apis_by_plugin(payload.plugin_id)
removed_llm_providers = client_registry.unregister_plugin_providers(payload.plugin_id)
self._authorization.revoke_permission_token(payload.plugin_id)
removed_registration = self._registered_plugins.pop(payload.plugin_id, None) is not None
await self._unregister_all_message_gateway_drivers_for_plugin(payload.plugin_id)
@@ -841,6 +912,7 @@ class PluginRunnerSupervisor:
"reason": payload.reason,
"removed_components": removed_components,
"removed_apis": removed_apis,
"removed_llm_providers": removed_llm_providers,
"removed_registration": removed_registration,
}
)
@@ -1505,6 +1577,8 @@ class PluginRunnerSupervisor:
def _clear_runner_state(self) -> None:
"""清理当前 Runner 对应的 Host 侧注册状态。"""
for plugin_id in list(self._registered_plugins):
client_registry.unregister_plugin_providers(plugin_id)
self._authorization.clear()
self._api_registry.clear()
self._component_registry.clear()

View File

@@ -148,6 +148,35 @@ class PluginRuntimeManager(
validator = ManifestValidator(validate_python_package_dependencies=False)
return validator.build_plugin_dependency_map(plugin_dirs)
@classmethod
def _discover_llm_provider_conflicts(cls, plugin_dirs: Iterable[Path]) -> Dict[str, str]:
"""扫描插件 Manifest发现 LLM Provider client_type 冲突。
Args:
plugin_dirs: 需要扫描的插件根目录集合。
Returns:
Dict[str, str]: 需要阻止加载的插件 ID 与原因映射。
"""
validator = ManifestValidator(validate_python_package_dependencies=False)
provider_owners: Dict[str, List[str]] = {}
for _plugin_path, manifest in validator.iter_plugin_manifests(plugin_dirs, require_entrypoint=True):
for client_type in manifest.llm_provider_client_types:
provider_owners.setdefault(client_type, []).append(manifest.id)
blocked_reasons: Dict[str, str] = {}
for client_type, plugin_ids in provider_owners.items():
unique_plugin_ids = sorted(set(plugin_ids))
if len(unique_plugin_ids) <= 1:
continue
reason = (
f"LLM Provider client_type 冲突: {client_type} 被以下插件重复声明: "
f"{', '.join(unique_plugin_ids)}"
)
for plugin_id in unique_plugin_ids:
blocked_reasons[plugin_id] = reason
return blocked_reasons
@classmethod
def _build_group_start_order(
cls,
@@ -271,7 +300,11 @@ class PluginRuntimeManager(
"""
result = await self._plugin_dependency_pipeline.execute(plugin_dirs)
changed_plugin_ids = self._set_blocked_plugin_reasons(result.blocked_plugin_reasons)
blocked_plugin_reasons = {
**result.blocked_plugin_reasons,
**self._discover_llm_provider_conflicts(plugin_dirs),
}
changed_plugin_ids = self._set_blocked_plugin_reasons(blocked_plugin_reasons)
return DependencySyncState(
blocked_changed_plugin_ids=changed_plugin_ids,
environment_changed=result.environment_changed,

View File

@@ -199,6 +199,21 @@ class ComponentDeclaration(BaseModel):
"""组件元数据"""
class LLMProviderDeclaration(BaseModel):
"""单个 LLM Provider 声明。"""
client_type: str = Field(description="客户端类型标识,对应模型配置中的 api_providers[].client_type")
"""客户端类型标识。"""
name: str = Field(default="", description="Provider 展示名称")
"""Provider 展示名称。"""
description: str = Field(default="", description="Provider 描述")
"""Provider 描述。"""
version: str = Field(default="1.0.0", description="Provider 实现版本")
"""Provider 实现版本。"""
metadata: Dict[str, Any] = Field(default_factory=dict, description="Provider 元数据")
"""Provider 元数据。"""
class RegisterPluginPayload(BaseModel):
"""插件组件注册请求载荷。
@@ -212,6 +227,8 @@ class RegisterPluginPayload(BaseModel):
"""插件版本"""
components: List[ComponentDeclaration] = Field(default_factory=list, description="组件列表")
"""组件列表"""
llm_providers: List[LLMProviderDeclaration] = Field(default_factory=list, description="LLM Provider 声明列表")
"""LLM Provider 声明列表。"""
capabilities_required: List[str] = Field(default_factory=list, description="所需能力列表")
"""所需能力列表"""
dependencies: List[str] = Field(default_factory=list, description="插件级依赖插件 ID 列表")
@@ -254,6 +271,17 @@ class InvokeResultPayload(BaseModel):
"""返回值"""
class LLMProviderInvokePayload(BaseModel):
"""plugin.invoke_llm_provider 请求 payload。"""
client_type: str = Field(description="目标 LLM Provider 客户端类型")
"""目标 LLM Provider 客户端类型。"""
operation: str = Field(description="请求操作类型")
"""请求操作类型,如 response、embedding、audio_transcription。"""
request: Dict[str, Any] = Field(default_factory=dict, description="已序列化的 LLM 请求")
"""已序列化的 LLM 请求。"""
# ====== 能力调用消息 ======
class CapabilityRequestPayload(BaseModel):
"""cap.* 请求 payload插件 -> Host 能力调用)"""

View File

@@ -7,7 +7,7 @@
from functools import lru_cache
from importlib import metadata as importlib_metadata
from pathlib import Path
from typing import Annotated, Any, Dict, Iterable, List, Literal, Optional, Tuple, Union
from typing import Annotated, Any, Dict, Iterable, List, Literal, Optional, Set, Tuple, Union
import json
import re
@@ -453,6 +453,38 @@ class PythonPackageDependencyDefinition(_StrictManifestModel):
return value
class LLMProviderManifestDeclaration(_StrictManifestModel):
"""插件 Manifest 中声明的 LLM Provider。"""
client_type: str = Field(description="客户端类型标识,对应模型配置中的 api_providers[].client_type")
"""客户端类型标识。"""
name: str = Field(default="", description="Provider 展示名称")
"""Provider 展示名称。"""
description: str = Field(default="", description="Provider 描述")
"""Provider 描述。"""
version: str = Field(default="1.0.0", description="Provider 实现版本")
"""Provider 实现版本。"""
@field_validator("client_type")
@classmethod
def _validate_client_type(cls, value: str) -> str:
"""校验客户端类型标识。
Args:
value: 原始客户端类型标识。
Returns:
str: 合法的客户端类型标识。
Raises:
ValueError: 当客户端类型为空时抛出。
"""
normalized_value = str(value or "").strip()
if not normalized_value:
raise ValueError("client_type 不能为空")
return normalized_value
ManifestDependencyDefinition = Annotated[
Union[PluginDependencyDefinition, PythonPackageDependencyDefinition],
Field(discriminator="type"),
@@ -472,6 +504,10 @@ class PluginManifest(_StrictManifestModel):
host_application: ManifestVersionRange = Field(description="Host 兼容区间")
sdk: ManifestVersionRange = Field(description="SDK 兼容区间")
dependencies: List[ManifestDependencyDefinition] = Field(default_factory=list, description="依赖声明")
llm_providers: List[LLMProviderManifestDeclaration] = Field(
default_factory=list,
description="插件静态声明的 LLM Provider 列表",
)
capabilities: List[str] = Field(description="插件声明的能力请求")
i18n: ManifestI18n = Field(description="国际化配置")
id: str = Field(description="稳定插件 ID")
@@ -567,6 +603,23 @@ class PluginManifest(_StrictManifestModel):
return self
@model_validator(mode="after")
def _validate_llm_providers(self) -> "PluginManifest":
"""校验 LLM Provider 静态声明集合。
Returns:
PluginManifest: 当前对象本身。
Raises:
ValueError: 当同一 Manifest 内重复声明 client_type 时抛出。
"""
client_types: Set[str] = set()
for provider in self.llm_providers:
if provider.client_type in client_types:
raise ValueError(f"存在重复的 LLM Provider 声明: {provider.client_type}")
client_types.add(provider.client_type)
return self
@property
def plugin_dependencies(self) -> List[PluginDependencyDefinition]:
"""返回插件级依赖列表。
@@ -598,6 +651,15 @@ class PluginManifest(_StrictManifestModel):
"""
return [dependency.id for dependency in self.plugin_dependencies]
@property
def llm_provider_client_types(self) -> List[str]:
"""返回 Manifest 静态声明的 LLM Provider client_type 列表。
Returns:
List[str]: 当前插件声明的 LLM Provider client_type。
"""
return [provider.client_type for provider in self.llm_providers]
class ManifestValidator:
"""严格的插件 Manifest v2 校验器。"""

View File

@@ -54,6 +54,7 @@ class PluginMeta:
self.capabilities_required = list(manifest.capabilities)
self.dependencies: List[str] = list(manifest.plugin_dependency_ids)
self.component_handlers: Dict[str, str] = {}
self.llm_provider_handlers: Dict[str, str] = {}
class PluginLoader:

View File

@@ -48,6 +48,8 @@ from src.plugin_runtime.protocol.envelope import (
InspectPluginConfigResultPayload,
InvokePayload,
InvokeResultPayload,
LLMProviderDeclaration,
LLMProviderInvokePayload,
RegisterPluginPayload,
ReloadPluginPayload,
ReloadPluginResultPayload,
@@ -891,6 +893,7 @@ class PluginRunner:
self._rpc_client.register_method("plugin.invoke_api", self._handle_invoke)
self._rpc_client.register_method("plugin.invoke_tool", self._handle_invoke)
self._rpc_client.register_method("plugin.invoke_message_gateway", self._handle_invoke)
self._rpc_client.register_method("plugin.invoke_llm_provider", self._handle_llm_provider_invoke)
self._rpc_client.register_method("plugin.emit_event", self._handle_event_invoke)
self._rpc_client.register_method("plugin.invoke_hook", self._handle_hook_invoke)
self._rpc_client.register_method("plugin.health", self._handle_health)
@@ -980,6 +983,7 @@ class PluginRunner:
"""
# 收集插件组件声明
components: List[ComponentDeclaration] = []
llm_providers: List[LLMProviderDeclaration] = []
config_reload_subscriptions: List[str] = []
instance = meta.instance
@@ -1016,11 +1020,43 @@ class PluginRunner:
)
if hasattr(instance, "get_config_reload_subscriptions"):
config_reload_subscriptions = list(instance.get_config_reload_subscriptions())
if hasattr(instance, "get_llm_providers"):
meta.llm_provider_handlers.clear()
for provider_info in instance.get_llm_providers():
if not isinstance(provider_info, dict):
continue
client_type = str(provider_info.get("client_type", "") or "").strip()
raw_metadata = provider_info.get("metadata", {})
provider_metadata = raw_metadata if isinstance(raw_metadata, dict) else {}
if client_type:
handler_name = str(provider_metadata.get("handler_name", client_type) or client_type).strip()
meta.llm_provider_handlers[client_type] = handler_name or client_type
llm_providers.append(
LLMProviderDeclaration(
client_type=client_type,
name=str(provider_info.get("name", "") or "").strip(),
description=str(provider_info.get("description", "") or "").strip(),
version=str(provider_info.get("version", "1.0.0") or "1.0.0").strip() or "1.0.0",
metadata=provider_metadata,
)
)
declared_client_types = sorted(meta.manifest.llm_provider_client_types)
registered_client_types = sorted(provider.client_type for provider in llm_providers)
if declared_client_types != registered_client_types:
logger.error(
f"插件 {meta.plugin_id} LLM Provider 声明不一致: "
f"manifest={declared_client_types}, code={registered_client_types}"
)
return False
reg_payload = RegisterPluginPayload(
plugin_id=meta.plugin_id,
plugin_version=meta.version,
components=components,
llm_providers=llm_providers,
capabilities_required=meta.capabilities_required,
dependencies=meta.dependencies,
config_reload_subscriptions=config_reload_subscriptions,
@@ -1629,6 +1665,50 @@ class PluginRunner:
resp_payload = InvokeResultPayload(success=False, result=str(e))
return envelope.make_response(payload=resp_payload.model_dump())
async def _handle_llm_provider_invoke(self, envelope: Envelope) -> Envelope:
"""处理 LLM Provider 调用请求。
Args:
envelope: RPC 请求信封。
Returns:
Envelope: 标准化后的 Provider 调用结果。
"""
try:
invoke = LLMProviderInvokePayload.model_validate(envelope.payload)
except Exception as exc:
return envelope.make_error_response(ErrorCode.E_BAD_PAYLOAD.value, str(exc))
plugin_id = envelope.plugin_id
meta = self._loader.get_plugin(plugin_id)
if meta is None:
return envelope.make_error_response(
ErrorCode.E_PLUGIN_NOT_FOUND.value,
f"插件 {plugin_id} 未加载",
)
handler_name = meta.llm_provider_handlers.get(invoke.client_type, "")
handler_method = getattr(meta.instance, handler_name, None) if handler_name else None
if handler_method is None or not callable(handler_method):
return envelope.make_error_response(
ErrorCode.E_METHOD_NOT_ALLOWED.value,
f"插件 {plugin_id} 未注册 LLM Provider: {invoke.client_type}",
)
try:
result = handler_method(operation=invoke.operation, request=invoke.request)
if inspect.isawaitable(result):
result = await result
resp_payload = InvokeResultPayload(success=True, result=result)
return envelope.make_response(payload=resp_payload.model_dump())
except Exception as exc:
logger.error(
f"插件 {plugin_id} LLM Provider {invoke.client_type} 执行异常: {exc}",
exc_info=True,
)
resp_payload = InvokeResultPayload(success=False, result=str(exc))
return envelope.make_response(payload=resp_payload.model_dump())
async def _handle_event_invoke(self, envelope: Envelope) -> Envelope:
"""处理 EventHandler 调用请求